BEiT

Overview

The BEiT model was proposed in BEiT: BERT Pre-Training of Image Transformers by Hangbo Bao, Li Dong and Furu Wei. Inspired by BERT, BEiT is the first paper that makes self-supervised pre-training of Vision Transformers (ViTs) outperform supervised pre-training. Rather than pre-training the model to predict the class of an image (as done in the original ViT paper), BEiT models are pre-trained to predict visual tokens from the codebook of OpenAI’s DALL-E model given masked patches.

The abstract from the paper is the following:

We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e., image patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first “tokenize” the original image into visual tokens. Then we randomly mask some image patches and feed them into the backbone Transformer. The pre-training objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. Experimental results on image classification and semantic segmentation show that our model achieves competitive results with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains 86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%).
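The masking step described in the abstract can be sketched in plain Python. This is a simplified illustration, not the BEiT implementation: it masks a uniformly random subset of patch positions (the paper uses blockwise masking), and both the 0.4 mask ratio and the helper name random_patch_mask are assumptions chosen for illustration. The pre-training objective only has to predict visual tokens at the positions marked True.

```python
import random


def random_patch_mask(image_size=224, patch_size=16, mask_ratio=0.4, seed=0):
    """Return a flat list of booleans, one per patch (True = masked).

    Hypothetical helper: patches are sampled uniformly at random rather than
    in blocks, so this only approximates the masking used to pre-train BEiT.
    """
    grid = image_size // patch_size          # 224 // 16 = 14 patches per side
    num_patches = grid * grid                # 14 * 14 = 196 patches
    num_masked = int(round(num_patches * mask_ratio))
    rng = random.Random(seed)                # seeded for reproducibility
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]


mask = random_patch_mask()
print(len(mask), sum(mask))  # 196 78
```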

Tips:

  • BEiT models are regular Vision Transformers, but pre-trained in a self-supervised rather than supervised way. They outperform both the original model (ViT) and Data-efficient Image Transformers (DeiT) when fine-tuned on ImageNet-1K and CIFAR-100.

  • As the BEiT models expect each image to be of the same size (resolution), one can use BeitFeatureExtractor to resize (or rescale) and normalize images for the model.

  • Both the patch resolution and image resolution used during pre-training or fine-tuning are reflected in the name of each checkpoint. For example, microsoft/beit-base-patch16-224 refers to a base-sized architecture with patch resolution of 16x16 and fine-tuning resolution of 224x224. All checkpoints can be found on the hub.

  • The available checkpoints are either (1) pre-trained on ImageNet-22k (a collection of 14 million images and 22k classes) only, (2) also fine-tuned on ImageNet-22k or (3) also fine-tuned on ImageNet-1k (also referred to as ILSVRC 2012, a collection of 1.3 million images and 1,000 classes).

  • BEiT uses relative position embeddings, inspired by the T5 model. During pre-training, the authors shared the relative position bias among the several self-attention layers. During fine-tuning, each layer’s relative position bias is initialized with the shared relative position bias obtained after pre-training. Note that if one wants to pre-train a model from scratch, one needs to set either the use_relative_position_bias or the use_shared_relative_position_bias attribute of BeitConfig to True in order to add position embeddings.
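The lookup behind these relative position biases can be sketched in plain Python. The sketch assumes a Swin/BEiT-style scheme: each pair of patches on a window_h x window_w grid maps to one of (2 * window_h - 1) * (2 * window_w - 1) learned biases according to their 2-D offset, and three extra slots cover interactions with the [CLS] token. The function name is hypothetical and the code only builds the index table; in the model, each entry selects a learned per-head bias that is added to the attention logits (shared by all layers when the bias is shared).

```python
def relative_position_index(window_h, window_w):
    """Build the (N + 1) x (N + 1) index table, N = window_h * window_w.

    Entry [i][j] selects which learned bias is used for query token i and
    key token j. Token 0 is the [CLS] token; tokens 1..N are patches.
    """
    num_relative_distance = (2 * window_h - 1) * (2 * window_w - 1) + 3
    coords = [(h, w) for h in range(window_h) for w in range(window_w)]
    n = len(coords)
    index = [[0] * (n + 1) for _ in range(n + 1)]
    for i, (hi, wi) in enumerate(coords):
        for j, (hj, wj) in enumerate(coords):
            dh = hi - hj + window_h - 1  # shift row offset into [0, 2*window_h - 2]
            dw = wi - wj + window_w - 1  # shift column offset into [0, 2*window_w - 2]
            index[i + 1][j + 1] = dh * (2 * window_w - 1) + dw
    # The last three slots are reserved for pairs involving [CLS].
    for k in range(n + 1):
        index[0][k] = num_relative_distance - 3  # [CLS] query, any key
        index[k][0] = num_relative_distance - 2  # any query, [CLS] key
    index[0][0] = num_relative_distance - 1      # [CLS] attending to itself
    return num_relative_distance, index


num_rel, table = relative_position_index(14, 14)
print(num_rel, len(table))  # 732 197
```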
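The checkpoint naming convention described in the tips above can also be read programmatically. The helper below is a hypothetical sketch, not part of transformers: it assumes names of the form beit-<size>-patch<P>-<R>, optionally followed by a suffix such as pt22k-ft22k, and the loaded configuration (patch_size, image_size) remains the authoritative source.

```python
import re

# Hypothetical pattern covering names such as "microsoft/beit-base-patch16-224"
# and "microsoft/beit-base-patch16-224-pt22k-ft22k".
CHECKPOINT_RE = re.compile(
    r"^(?:[\w.-]+/)?beit-(?P<size>[a-z]+)-patch(?P<patch>\d+)-(?P<res>\d+)(?:-(?P<suffix>.+))?$"
)


def parse_beit_checkpoint(name):
    """Split a BEiT checkpoint name into size, patch size and resolution."""
    match = CHECKPOINT_RE.match(name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {name!r}")
    patch = int(match["patch"])
    res = int(match["res"])
    return {
        "size": match["size"],
        "patch_size": patch,
        "image_size": res,
        "num_patches": (res // patch) ** 2,  # patches per image
        "suffix": match["suffix"],           # None when there is no suffix
    }


print(parse_beit_checkpoint("microsoft/beit-base-patch16-224"))
# {'size': 'base', 'patch_size': 16, 'image_size': 224, 'num_patches': 196, 'suffix': None}
```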

This model was contributed by nielsr. The original code can be found here.

BeitConfig

class transformers.BeitConfig(vocab_size=8192, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act='gelu', hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, initializer_range=0.02, layer_norm_eps=1e-12, is_encoder_decoder=False, image_size=224, patch_size=16, num_channels=3, use_mask_token=False, use_absolute_position_embeddings=False, use_relative_position_bias=False, use_shared_relative_position_bias=False, layer_scale_init_value=0.1, drop_path_rate=0.1, use_mean_pooling=True, **kwargs)

This is the configuration class to store the configuration of a BeitModel. It is used to instantiate a BEiT model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the BEiT microsoft/beit-base-patch16-224-in22k architecture.

Parameters
  • vocab_size (int, optional, defaults to 8192) – Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during pre-training.

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to "gelu") – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" are supported.

  • hidden_dropout_prob (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_probs_dropout_prob (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-12) – The epsilon used by the layer normalization layers.

  • gradient_checkpointing (bool, optional, defaults to False) – If True, use gradient checkpointing to save memory at the expense of a slower backward pass.

  • image_size (int, optional, defaults to 224) – The size (resolution) of each image.

  • patch_size (int, optional, defaults to 16) – The size (resolution) of each patch.

  • num_channels (int, optional, defaults to 3) – The number of input channels.

  • use_mask_token (bool, optional, defaults to False) – Whether to use a mask token for masked image modeling.

  • use_absolute_position_embeddings (bool, optional, defaults to False) – Whether to use BERT-style absolute position embeddings.

  • use_relative_position_bias (bool, optional, defaults to False) – Whether to use T5-style relative position embeddings in the self-attention layers.

  • use_shared_relative_position_bias (bool, optional, defaults to False) – Whether to use the same relative position embeddings across all self-attention layers of the Transformer.

  • layer_scale_init_value (float, optional, defaults to 0.1) – Initial value of the layer scale applied in each Transformer layer: 0.1 for base-sized models, 1e-5 for large-sized models. Set to 0 to disable layer scale.

  • drop_path_rate (float, optional, defaults to 0.1) – Stochastic depth rate per sample (when applied in the main path of residual layers).

  • use_mean_pooling (bool, optional, defaults to True) – Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the CLS token, before applying the classification head.
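The effect of use_mean_pooling can be sketched on toy hidden states, given as one list of floats per token with the [CLS] token first. This is a simplified sketch with a hypothetical function name: it omits the layer normalization the model applies to the mean-pooled vector and works on plain lists instead of tensors.

```python
def pool(final_hidden_states, use_mean_pooling=True):
    """Pool a sequence of token vectors ([CLS] first) into one vector."""
    if use_mean_pooling:
        patch_tokens = final_hidden_states[1:]  # drop the [CLS] token
        dim = len(patch_tokens[0])
        # Average each feature over all patch tokens.
        return [sum(tok[d] for tok in patch_tokens) / len(patch_tokens) for d in range(dim)]
    return list(final_hidden_states[0])  # final hidden state of [CLS]


tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]  # [CLS], patch 1, patch 2
print(pool(tokens))                          # [2.0, 3.0]
print(pool(tokens, use_mean_pooling=False))  # [9.0, 9.0]
```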

Example:

>>> from transformers import BeitModel, BeitConfig

>>> # Initializing a BEiT beit-base-patch16-224-in22k style configuration
>>> configuration = BeitConfig()

>>> # Initializing a model from the beit-base-patch16-224-in22k style configuration
>>> model = BeitModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

BeitFeatureExtractor

class transformers.BeitFeatureExtractor(do_resize=True, size=256, resample=3, do_center_crop=True, crop_size=224, do_normalize=True, image_mean=None, image_std=None, **kwargs)

Constructs a BEiT feature extractor.

This feature extractor inherits from FeatureExtractionMixin which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

Parameters
  • do_resize (bool, optional, defaults to True) – Whether to resize the input to a certain size.

  • size (int or Tuple(int), optional, defaults to 256) – Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an integer is provided, then the input will be resized to (size, size). Only has an effect if do_resize is set to True.

  • resample (int, optional, defaults to PIL.Image.BICUBIC) – An optional resampling filter. This can be one of PIL.Image.NEAREST, PIL.Image.BOX, PIL.Image.BILINEAR, PIL.Image.HAMMING, PIL.Image.BICUBIC or PIL.Image.LANCZOS. Only has an effect if do_resize is set to True.

  • do_center_crop (bool, optional, defaults to True) – Whether to crop the input at the center. If the input size is smaller than crop_size along any edge, the image is padded with 0’s and then center cropped.

  • crop_size (int, optional, defaults to 224) – Desired output size when applying center-cropping. Only has an effect if do_center_crop is set to True.

  • do_normalize (bool, optional, defaults to True) – Whether or not to normalize the input with image_mean and image_std.

  • image_mean (List[float], optional, defaults to [0.5, 0.5, 0.5]) – The sequence of means for each channel, to be used when normalizing images.

  • image_std (List[float], optional, defaults to [0.5, 0.5, 0.5]) – The sequence of standard deviations for each channel, to be used when normalizing images.
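With the default settings, each image is resized to 256x256, center-cropped to 224x224 and normalized with a mean and standard deviation of 0.5 per channel. The arithmetic can be sketched as follows. It assumes pixel values are rescaled from [0, 255] to [0, 1] before normalization and that the crop offset is rounded down, which may differ by a pixel from the library's own rounding; both helper names are hypothetical.

```python
def center_crop_box(width, height, crop_size=224):
    """(left, top, right, bottom) of a centered crop.

    Assumes the image is at least crop_size pixels along each edge.
    """
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    return left, top, left + crop_size, top + crop_size


def normalize_pixel(value, mean=0.5, std=0.5):
    """Rescale an 8-bit value to [0, 1], then standardize it."""
    return (value / 255.0 - mean) / std


print(center_crop_box(256, 256))                 # (16, 16, 240, 240)
print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```

With mean and standard deviation both 0.5, the model therefore sees inputs in the range [-1, 1].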

__call__(images: Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, List[PIL.Image.Image], List[numpy.ndarray], List[torch.Tensor]], return_tensors: Optional[Union[str, transformers.file_utils.TensorType]] = None, **kwargs) → transformers.feature_extraction_utils.BatchFeature

Main method to prepare one or several image(s) for the model.

Warning

NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass PIL images.

Parameters
  • images (PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor]) – The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the number of channels and H and W are the image height and width.

  • return_tensors (str or TensorType, optional, defaults to 'np') –

    If set, will return tensors of a particular framework. Acceptable values are:

    • 'tf': Return TensorFlow tf.constant objects.

    • 'pt': Return PyTorch torch.Tensor objects.

    • 'np': Return NumPy np.ndarray objects.

    • 'jax': Return JAX jnp.ndarray objects.

Returns

A BatchFeature with the following fields:

  • pixel_values – Pixel values to be fed to a model, of shape (batch_size, num_channels, height, width).

Return type

BatchFeature

BeitModel

class transformers.BeitModel(config, add_pooling_layer=True)

The bare Beit Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

Parameters

config (BeitConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

forward(pixel_values=None, bool_masked_pos=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)

The BeitModel forward method, overrides the __call__() special method.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Parameters
  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Pixel values can be obtained using BeitFeatureExtractor. See transformers.BeitFeatureExtractor.__call__() for details.

  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) –

    Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,

    • 0 indicates the head is masked.

  • output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.

  • output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.

  • return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.
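How a head_mask is interpreted can be sketched in plain Python: a mask of shape (num_heads,) is applied identically to every layer, a mask of shape (num_layers, num_heads) gives each layer its own row, and a 0 entry multiplies that head's attention probabilities by zero so the head contributes nothing. The helper names are hypothetical; in the model these operations are performed on tensors.

```python
def expand_head_mask(head_mask, num_layers):
    """Return one mask row per layer; a 1-D mask is reused for every layer."""
    if head_mask and isinstance(head_mask[0], (int, float)):
        return [list(head_mask) for _ in range(num_layers)]
    if len(head_mask) != num_layers:
        raise ValueError("a 2-D head_mask needs one row per layer")
    return [list(row) for row in head_mask]


def apply_head_mask(attention_probs, layer_mask):
    """attention_probs: one (seq x seq) matrix per head; scale each by its mask value."""
    return [
        [[m * p for p in row] for row in head]
        for head, m in zip(attention_probs, layer_mask)
    ]


per_layer = expand_head_mask([1, 0, 1], num_layers=2)
print(per_layer)  # [[1, 0, 1], [1, 0, 1]]
```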

Returns

A BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BeitConfig) and inputs.

  • last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)) – If config.use_mean_pooling is True, the average of the last layer hidden states of the patch tokens (excluding the [CLS] token), passed through a final layer normalization. If it is False, the last layer hidden state of the [CLS] token.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
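For the default configuration (224x224 images, 16x16 patches), the shapes listed above can be worked out by hand; the sketch below does the arithmetic. It assumes the sequence consists of one [CLS] token followed by one token per patch, and the function name is hypothetical.

```python
def beit_output_shapes(batch_size=1, image_size=224, patch_size=16,
                       hidden_size=768, num_hidden_layers=12, num_attention_heads=12):
    """Expected output shapes of the bare model for a given configuration."""
    num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 by default
    seq_len = num_patches + 1                      # + 1 for the [CLS] token
    return {
        "last_hidden_state": (batch_size, seq_len, hidden_size),
        "pooler_output": (batch_size, hidden_size),
        # Embedding output plus one entry per layer.
        "hidden_states": [(batch_size, seq_len, hidden_size)] * (num_hidden_layers + 1),
        # One entry per layer.
        "attentions": [(batch_size, num_attention_heads, seq_len, seq_len)] * num_hidden_layers,
    }


shapes = beit_output_shapes()
print(shapes["last_hidden_state"])  # (1, 197, 768)
```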

Examples:

>>> from transformers import BeitFeatureExtractor, BeitModel
>>> from PIL import Image
>>> import requests

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
>>> model = BeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')

>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state

Return type

BaseModelOutputWithPooling or tuple(torch.FloatTensor)

BeitForMaskedImageModeling

class transformers.BeitForMaskedImageModeling(config)

Beit Model transformer with a ‘language’ modeling head on top (to predict visual tokens). This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

Parameters

config (BeitConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

forward(pixel_values=None, bool_masked_pos=None, head_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None)

The BeitForMaskedImageModeling forward method, overrides the __call__() special method.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Parameters
  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Pixel values can be obtained using BeitFeatureExtractor. See transformers.BeitFeatureExtractor.__call__() for details.

  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) –

    Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,

    • 0 indicates the head is masked.

  • output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.

  • output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.

  • return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.

  • bool_masked_pos (torch.BoolTensor of shape (batch_size, num_patches)) – Boolean masked positions. Indicates which patches are masked (1) and which aren’t (0).

  • labels (torch.LongTensor, optional) – Labels for computing the masked image modeling loss: the visual token indices (in [0, ..., config.vocab_size - 1]) of the masked patches, one per position where bool_masked_pos is True, so the tensor has shape (total number of masked patches in the batch,). A cross-entropy loss is computed on these masked positions only.
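A sketch of how bool_masked_pos and the visual-token targets fit together, assuming the loss is computed only at masked positions: the targets are the token ids at the masked positions, collected image by image and then patch by patch (the order that boolean-mask indexing produces in PyTorch). The helper name and the token ids below are hypothetical; real targets would come from an image tokenizer such as the DALL-E codebook.

```python
def gather_mim_labels(visual_tokens, bool_masked_pos):
    """Collect target token ids at masked positions, image by image, patch by patch."""
    labels = []
    for tokens, mask in zip(visual_tokens, bool_masked_pos):
        labels.extend(tok for tok, is_masked in zip(tokens, mask) if is_masked)
    return labels


# Two toy "images" with 4 patches each (hypothetical token ids).
visual_tokens = [[5, 7, 9, 11], [2, 4, 6, 8]]
bool_masked_pos = [[True, False, False, True], [False, True, False, False]]
print(gather_mim_labels(visual_tokens, bool_masked_pos))  # [5, 11, 4]
```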

Returns

A MaskedLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BeitConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Masked language modeling (MLM) loss.

  • logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

>>> from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling
>>> from PIL import Image
>>> import requests

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
>>> model = BeitForMaskedImageModeling.from_pretrained('microsoft/beit-base-patch16-224-pt22k')

>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits

Return type

MaskedLMOutput or tuple(torch.FloatTensor)

BeitForImageClassification

class transformers.BeitForImageClassification(config)

Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final hidden states of the patch tokens) e.g. for ImageNet.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

Parameters

config (BeitConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

forward(pixel_values=None, head_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None)

The BeitForImageClassification forward method, overrides the __call__() special method.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Parameters
  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Pixel values can be obtained using BeitFeatureExtractor. See transformers.BeitFeatureExtractor.__call__() for details.

  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) –

    Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,

    • 0 indicates the head is masked.

  • output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.

  • output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.

  • return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.

  • labels (torch.LongTensor of shape (batch_size,), optional) – Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss), If config.num_labels > 1 a classification loss is computed (Cross-Entropy).
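The rule for choosing the loss can be sketched in plain Python for a batch of logits. This mirrors the labels description above (mean squared error when config.num_labels == 1, cross-entropy otherwise); the function name is hypothetical, and the model itself computes these losses with PyTorch rather than with this code.

```python
import math


def classification_or_regression_loss(logits, labels, num_labels):
    """Mean loss over a batch; logits has one row of num_labels scores per example."""
    if num_labels == 1:
        # Regression: mean squared error between the single output and the target.
        return sum((row[0] - y) ** 2 for row, y in zip(logits, labels)) / len(labels)
    # Classification: cross-entropy, i.e. the mean of -log softmax(row)[label].
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[y]
    return total / len(labels)


print(classification_or_regression_loss([[1.0], [3.0]], [0.0, 1.0], num_labels=1))  # 2.5
```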

Returns

A SequenceClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BeitConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Classification (or regression if config.num_labels==1) loss.

  • logits (torch.FloatTensor of shape (batch_size, config.num_labels)) – Classification (or regression if config.num_labels==1) scores (before SoftMax).

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

>>> from transformers import BeitFeatureExtractor, BeitForImageClassification
>>> from PIL import Image
>>> import requests

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
>>> model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')

>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
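To inspect more than the single best class, the logits can be turned into probabilities with a softmax and ranked. The sketch below works on a plain list standing in for one row of logits (for example what logits[0].tolist() would give) and a toy id2label mapping; the values and the helper name are hypothetical.

```python
import math


def top_k(logits, id2label, k=5):
    """Return the k most likely (label, probability) pairs from one row of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


toy_id2label = {0: "cat", 1: "dog", 2: "remote"}
for label, prob in top_k([2.0, 0.5, 1.0], toy_id2label, k=2):
    print(f"{label}: {prob:.3f}")
```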

Return type

SequenceClassifierOutput or tuple(torch.FloatTensor)