DINOv2 with Registers
Overview
The DINOv2 with Registers model was proposed in Vision Transformers Need Registers by Timothée Darcet, Maxime Oquab, Julien Mairal, and Piotr Bojanowski.
The Vision Transformer (ViT) is a transformer encoder model (BERT-like) originally introduced to do supervised image classification on ImageNet.
Later work showed that ViTs also perform very well at self-supervised image feature extraction, that is, learning meaningful features (also called embeddings) from images without requiring any labels. Example papers include DINOv2 and MAE.
The authors of DINOv2 noticed that ViTs have artifacts in their attention maps. These arise because the model uses some image patches as "registers". The authors propose a fix: add a few new tokens (called "register" tokens), which you only use during pre-training (and throw away afterwards). This results in:
- no artifacts
- interpretable attention maps
- improved performance.
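As a rough illustration of the idea above, the sketch below lays out a token sequence that contains extra register tokens and then drops them again. The ordering `[CLS, registers, patches]` and both helper names are assumptions made for this example; they are not taken from the text above.

```python
def build_token_sequence(num_patches, num_register_tokens):
    """Return token labels in the assumed order: CLS, registers, patches."""
    registers = [f"reg_{i}" for i in range(num_register_tokens)]
    patches = [f"patch_{i}" for i in range(num_patches)]
    return ["cls"] + registers + patches


def drop_register_tokens(tokens, num_register_tokens):
    """Keep the CLS token and the patch tokens; discard the register tokens."""
    return tokens[:1] + tokens[1 + num_register_tokens:]


# Toy sequence: 4 image patches and 2 register tokens.
tokens = build_token_sequence(num_patches=4, num_register_tokens=2)
kept = drop_register_tokens(tokens, num_register_tokens=2)
print(tokens)
print(kept)
```

Only the CLS and patch positions survive the second step, which is the sense in which registers act as scratch space rather than as image features.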
The abstract from the paper is the following:
Transformers have recently emerged as a powerful tool for learning visual representations. In this paper, we identify and characterize artifacts in feature maps of both supervised and self-supervised ViT networks. The artifacts correspond to high-norm tokens appearing during inference primarily in low-informative background areas of images, that are repurposed for internal computations. We propose a simple yet effective solution based on providing additional tokens to the input sequence of the Vision Transformer to fill that role. We show that this solution fixes that problem entirely for both supervised and self-supervised models, sets a new state of the art for self-supervised visual models on dense visual prediction tasks, enables object discovery methods with larger models, and most importantly leads to smoother feature maps and attention maps for downstream visual processing.

Tips:
- Usage of DINOv2 with Registers is identical to DINOv2 without registers; you'll just get better performance.
This model was contributed by nielsr. The original code can be found here.
Dinov2WithRegistersConfig
class transformers.Dinov2WithRegistersConfig
< source >( hidden_size = 768 num_hidden_layers = 12 num_attention_heads = 12 mlp_ratio = 4 hidden_act = 'gelu' hidden_dropout_prob = 0.0 attention_probs_dropout_prob = 0.0 initializer_range = 0.02 layer_norm_eps = 1e-06 image_size = 224 patch_size = 16 num_channels = 3 qkv_bias = True layerscale_value = 1.0 drop_path_rate = 0.0 use_swiglu_ffn = False num_register_tokens = 4 out_features = None out_indices = None apply_layernorm = True reshape_hidden_states = True **kwargs )
Parameters
- hidden_size (int, optional, defaults to 768) — Dimensionality of the encoder layers and the pooler layer.
- num_hidden_layers (int, optional, defaults to 12) — Number of hidden layers in the Transformer encoder.
- num_attention_heads (int, optional, defaults to 12) — Number of attention heads for each attention layer in the Transformer encoder.
- mlp_ratio (int, optional, defaults to 4) — Ratio of the hidden size of the MLPs relative to the hidden_size.
- hidden_act (str or function, optional, defaults to "gelu") — The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
- hidden_dropout_prob (float, optional, defaults to 0.0) — The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- attention_probs_dropout_prob (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
- initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (float, optional, defaults to 1e-06) — The epsilon used by the layer normalization layers.
- image_size (int, optional, defaults to 224) — The size (resolution) of each image.
- patch_size (int, optional, defaults to 16) — The size (resolution) of each patch.
- num_channels (int, optional, defaults to 3) — The number of input channels.
- qkv_bias (bool, optional, defaults to True) — Whether to add a bias to the queries, keys and values.
- layerscale_value (float, optional, defaults to 1.0) — Initial value to use for layer scale.
- drop_path_rate (float, optional, defaults to 0.0) — Stochastic depth rate per sample (when applied in the main path of residual layers).
- use_swiglu_ffn (bool, optional, defaults to False) — Whether to use the SwiGLU feedforward neural network.
- num_register_tokens (int, optional, defaults to 4) — Number of register tokens to use.
- out_features (List[str], optional) — If used as backbone, list of features to output. Can be any of "stem", "stage1", "stage2", etc. (depending on how many stages the model has). If unset and out_indices is set, will default to the corresponding stages. If unset and out_indices is unset, will default to the last stage. Must be in the same order as defined in the stage_names attribute.
- out_indices (List[int], optional) — If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how many stages the model has). If unset and out_features is set, will default to the corresponding stages. If unset and out_features is unset, will default to the last stage. Must be in the same order as defined in the stage_names attribute.
- apply_layernorm (bool, optional, defaults to True) — Whether to apply layer normalization to the feature maps in case the model is used as backbone.
- reshape_hidden_states (bool, optional, defaults to True) — Whether to reshape the feature maps to 4D tensors of shape (batch_size, hidden_size, height, width) in case the model is used as backbone. If False, the feature maps will be 3D tensors of shape (batch_size, seq_len, hidden_size).
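The effect of reshape_hidden_states can be pictured with a small pure-Python sketch: the patch tokens of a single image, of shape `(seq_len, hidden_size)`, are rearranged into a `(hidden_size, height, width)` grid. This assumes row-major patch order and that the CLS and register tokens have already been removed; `tokens_to_feature_map` is a hypothetical helper written for this illustration, not a library function.

```python
def tokens_to_feature_map(patch_tokens, height, width):
    """Rearrange (height * width, hidden_size) tokens into (hidden_size, height, width)."""
    if len(patch_tokens) != height * width:
        raise ValueError("number of patch tokens must equal height * width")
    hidden_size = len(patch_tokens[0])
    return [
        [
            [patch_tokens[row * width + col][channel] for col in range(width)]
            for row in range(height)
        ]
        for channel in range(hidden_size)
    ]


# Toy input: a 2x3 grid of patches, each with hidden_size = 2.
patch_tokens = [[i, 10 * i] for i in range(6)]
feature_map = tokens_to_feature_map(patch_tokens, height=2, width=3)
print(feature_map)
```

With a real model the same rearrangement would be applied per batch element, giving the 4D shape described above.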
This is the configuration class to store the configuration of a Dinov2WithRegistersModel. It is used to instantiate a Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DINOv2 with Registers facebook/dinov2-with-registers-base architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Example:
>>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
>>> # Initializing a Dinov2WithRegisters base style configuration
>>> configuration = Dinov2WithRegistersConfig()
>>> # Initializing a model (with random weights) from the base style configuration
>>> model = Dinov2WithRegistersModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
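To make the default values more concrete, here is a small sketch of the sizes they imply. It assumes a standard ViT layout: one CLS token, non-overlapping square patches, and an MLP hidden size of `hidden_size * mlp_ratio` (the non-SwiGLU case). `derived_sizes` is a hypothetical helper, not part of the library.

```python
def derived_sizes(hidden_size=768, num_attention_heads=12, mlp_ratio=4,
                  image_size=224, patch_size=16, num_register_tokens=4):
    """Compute a few quantities implied by the configuration defaults."""
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return {
        # Each attention head works on an equal slice of the hidden dimension.
        "head_dim": hidden_size // num_attention_heads,
        # Assumed plain MLP (use_swiglu_ffn=False).
        "mlp_hidden_size": hidden_size * mlp_ratio,
        "num_patches": num_patches,
        # Assumed layout: 1 CLS token + register tokens + patch tokens.
        "sequence_length": 1 + num_register_tokens + num_patches,
    }


sizes = derived_sizes()
print(sizes)
```

Pretrained checkpoints may use a different patch size or input resolution than these defaults, so the numbers for a loaded model can differ.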
Dinov2WithRegistersModel
class transformers.Dinov2WithRegistersModel
< source >( config: Dinov2WithRegistersConfig )
Parameters
- config (Dinov2WithRegistersConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
< source >( pixel_values: typing.Optional[torch.Tensor] = None bool_masked_pos: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See BitImageProcessor.preprocess() for details.
- bool_masked_pos (torch.BoolTensor of shape (batch_size, sequence_length)) — Boolean masked positions. Indicates which patches are masked (1) and which aren’t (0). Only relevant for pre-training.
- head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
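The head_mask convention described above (1 keeps a head, 0 nullifies it) can be sketched in plain Python for the `(num_layers, num_heads)` form. `build_head_mask` is a hypothetical helper; the nested list it returns would still need to be converted to a float tensor, for instance with `torch.tensor(...)`, before being passed to the model.

```python
def build_head_mask(num_layers, num_heads, heads_to_mask):
    """Build a (num_layers, num_heads) mask: 1.0 keeps a head, 0.0 nullifies it.

    heads_to_mask maps a layer index to the head indices to nullify in that layer.
    """
    mask = [[1.0] * num_heads for _ in range(num_layers)]
    for layer, heads in heads_to_mask.items():
        for head in heads:
            mask[layer][head] = 0.0
    return mask


# Nullify heads 0 and 3 in the first layer and head 5 in the last layer.
head_mask = build_head_mask(num_layers=12, num_heads=12,
                            heads_to_mask={0: [0, 3], 11: [5]})
print(head_mask[0])
```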
Returns
transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)
A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Dinov2WithRegistersConfig) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
- pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)) — Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns the classification token after processing through a linear layer and a tanh activation function. The linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The Dinov2WithRegistersModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoImageProcessor, Dinov2WithRegistersModel
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
>>> image = dataset["test"]["image"][0]
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
>>> model = Dinov2WithRegistersModel.from_pretrained("facebook/dinov2-with-registers-base")
>>> inputs = image_processor(image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 257, 768]
Dinov2WithRegistersForImageClassification
class transformers.Dinov2WithRegistersForImageClassification
< source >( config: Dinov2WithRegistersConfig )
Parameters
- config (Dinov2WithRegistersConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token) e.g. for ImageNet.
This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
< source >( pixel_values: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See BitImageProcessor.preprocess() for details.
- head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- labels (torch.LongTensor of shape (batch_size,), optional) — Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss); if config.num_labels > 1 a classification loss is computed (Cross-Entropy).
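The rule for labels above (regression when config.num_labels == 1, classification otherwise) can be sketched in plain Python. This is an illustration only: it assumes mean reduction over the batch, and the function names are hypothetical rather than the library's own implementation.

```python
import math


def mse_loss(predictions, targets):
    """Mean squared error over the batch."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def cross_entropy_loss(logits, labels):
    """Mean cross-entropy over the batch, using a stable log-sum-exp."""
    total = 0.0
    for row, label in zip(logits, labels):
        max_logit = max(row)
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def classification_head_loss(logits, labels, num_labels):
    """Pick the loss from num_labels: 1 means regression, more means classification."""
    if num_labels == 1:
        return mse_loss([row[0] for row in logits], labels)
    return cross_entropy_loss(logits, labels)


# Two-class toy batch with uniform logits: the loss equals log(2).
print(classification_head_loss([[0.0, 0.0]], [1], num_labels=2))
```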
Returns
transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)
A transformers.modeling_outputs.ImageClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Dinov2WithRegistersConfig) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.
- logits (torch.FloatTensor of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size). Hidden-states (also called feature maps) of the model at the output of each stage.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The Dinov2WithRegistersForImageClassification forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoImageProcessor, Dinov2WithRegistersForImageClassification
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
>>> image = dataset["test"]["image"][0]
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-small-imagenet1k-1-layer")
>>> model = Dinov2WithRegistersForImageClassification.from_pretrained("facebook/dinov2-with-registers-small-imagenet1k-1-layer")
>>> inputs = image_processor(image, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
tabby, tabby cat
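Because the logits are scores before SoftMax, turning them into probabilities takes one more step. The sketch below does this in plain Python on toy values; the three-entry `toy_id2label` mapping and the logits are made up for illustration and stand in for `model.config.id2label` and one row of the real model output.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Hypothetical logits and label mapping for a single image.
toy_logits = [2.0, 0.5, -1.0]
toy_id2label = {0: "tabby, tabby cat", 1: "tiger cat", 2: "Egyptian cat"}
predictions = top_k(toy_logits, toy_id2label, k=2)
for label, prob in predictions:
    print(f"{label}: {prob:.3f}")
```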