Vision Transformer (ViT)
Note
This model was introduced recently, so its API has not been tested extensively. There may be bugs, and fixing them may require slight breaking changes in future releases. If you see something strange, file a GitHub issue.
Overview
The Vision Transformer (ViT) model was proposed in An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. It’s the first paper that successfully trains a Transformer encoder on ImageNet, attaining very good results compared to familiar convolutional architectures.
The abstract from the paper is the following:
While the Transformer architecture has become the de-facto standard for natural language processing tasks, its applications to computer vision remain limited. In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train.
Tips:
To feed images to the Transformer encoder, each image is split into a sequence of fixed-size non-overlapping patches, which are then linearly embedded. A [CLS] token is added to serve as representation of an entire image, which can be used for classification. The authors also add absolute position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder.
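As a rough illustration of the patch-splitting step above, the following sketch computes how many tokens the encoder receives for a given image and patch size. The helper name vit_sequence_length is hypothetical (it is not part of the library), and the sketch assumes square images whose side is an exact multiple of the patch size.

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Number of tokens fed to the encoder: one per patch plus one [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return num_patches + 1  # +1 for the [CLS] token


print(vit_sequence_length(224, 16))  # 14 * 14 patches + [CLS] = 197
print(vit_sequence_length(384, 16))  # 24 * 24 patches + [CLS] = 577
```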
As the Vision Transformer expects each image to be of the same size (resolution), one can use ViTFeatureExtractor to resize (or rescale) and normalize images for the model.
Both the patch resolution and image resolution used during pre-training or fine-tuning are reflected in the name of each checkpoint. For example, google/vit-base-patch16-224 refers to a base-sized architecture with patch resolution of 16x16 and fine-tuning resolution of 224x224. All checkpoints can be found on the hub.
The available checkpoints are either (1) pre-trained on ImageNet-21k (a collection of 14 million images and 21k classes) only, or (2) also fine-tuned on ImageNet (also referred to as ILSVRC 2012, a collection of 1.3 million images and 1,000 classes).
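The checkpoint naming convention can be read off mechanically. The sketch below is illustrative only: parse_vit_checkpoint is a hypothetical helper, and treating a trailing -in21k as "pre-trained on ImageNet-21k only" is an assumption based on the checkpoint names used in the examples on this page.

```python
import re

# Matches names such as "google/vit-base-patch16-224" and
# "google/vit-base-patch16-224-in21k". The "-in21k" suffix handling is an
# assumption, not a documented rule.
CHECKPOINT_RE = re.compile(
    r"vit-(?P<size>[a-z]+)-patch(?P<patch>\d+)-(?P<resolution>\d+)(?P<suffix>-in21k)?$"
)


def parse_vit_checkpoint(name: str) -> dict:
    """Split a ViT checkpoint name into its size, patch and resolution parts."""
    match = CHECKPOINT_RE.search(name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {name!r}")
    return {
        "size": match["size"],
        "patch_size": int(match["patch"]),
        "resolution": int(match["resolution"]),
        "imagenet21k_only": match["suffix"] is not None,
    }


print(parse_vit_checkpoint("google/vit-base-patch16-224"))
```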
The Vision Transformer was pre-trained using a resolution of 224x224. During fine-tuning, it is often beneficial to use a higher resolution than pre-training (Touvron et al., 2019), (Kolesnikov et al., 2020). In order to fine-tune at higher resolution, the authors perform 2D interpolation of the pre-trained position embeddings, according to their location in the original image.
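The position-embedding resizing described above can be pictured as resampling a grid. The following is a minimal sketch in plain Python, assuming bilinear interpolation with aligned corners on a single embedding channel; the interpolation mode actually used by the authors or by the library may differ, and resize_grid is a hypothetical helper.

```python
def resize_grid(grid, new_size):
    """Bilinearly resize a square 2D grid (list of lists of floats)."""
    old_size = len(grid)
    if old_size == 1 or new_size == 1:
        # Degenerate case: nothing to interpolate between.
        return [[grid[0][0]] * new_size for _ in range(new_size)]
    scale = (old_size - 1) / (new_size - 1)  # corners map onto corners
    out = []
    for i in range(new_size):
        y = i * scale
        y0 = min(int(y), old_size - 2)
        dy = y - y0
        row = []
        for j in range(new_size):
            x = j * scale
            x0 = min(int(x), old_size - 2)
            dx = x - x0
            top = grid[y0][x0] * (1 - dx) + grid[y0][x0 + 1] * dx
            bottom = grid[y0 + 1][x0] * (1 - dx) + grid[y0 + 1][x0 + 1] * dx
            row.append(top * (1 - dy) + bottom * dy)
        out.append(row)
    return out


# A 224x224 image with 16x16 patches gives a 14x14 grid of positions;
# fine-tuning at 384x384 needs a 24x24 grid instead.
old_side, new_side = 224 // 16, 384 // 16
pretrained = [[float(r * old_side + c) for c in range(old_side)] for r in range(old_side)]
resized = resize_grid(pretrained, new_side)
print(len(resized), len(resized[0]))
```

In a real model each grid cell holds a hidden_size-dimensional vector rather than a single float, and the [CLS] position embedding is kept aside because it has no location in the image.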
The best results are obtained with supervised pre-training, which is not the case in NLP. The authors also performed an experiment with a self-supervised pre-training objective, namely masked patch prediction (inspired by masked language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant improvement of 2% over training from scratch, but still 4% behind supervised pre-training.
Several follow-up works build on the original Vision Transformer:
DeiT (Data-efficient Image Transformers) by Facebook AI. DeiT models are distilled vision transformers. Refer to DeiT’s documentation page. The authors of DeiT also released more efficiently trained ViT models, which you can directly plug into ViTModel or ViTForImageClassification. There are 4 variants available (in 3 different sizes): facebook/deit-tiny-patch16-224, facebook/deit-small-patch16-224, facebook/deit-base-patch16-224 and facebook/deit-base-patch16-384. Note that one should use DeiTFeatureExtractor in order to prepare images for the model.
BEiT (BERT pre-training of Image Transformers) by Microsoft Research. BEiT models outperform supervised pre-trained vision transformers using a self-supervised method inspired by BERT (masked image modeling) and based on a VQ-VAE. Refer to BEiT’s documentation page.
DINO (a method for self-supervised training of Vision Transformers) by Facebook AI. Vision Transformers trained using the DINO method show very interesting properties not seen with convolutional models. They are capable of segmenting objects, without having ever been trained to do so. DINO checkpoints can be found on the hub.
This model was contributed by nielsr. The original code (written in JAX) can be found here.
Note that we converted the weights from Ross Wightman’s timm library, who already converted the weights from JAX to PyTorch. Credits go to him!
ViTConfig
class transformers.ViTConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act='gelu', hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, initializer_range=0.02, layer_norm_eps=1e-12, is_encoder_decoder=False, image_size=224, patch_size=16, num_channels=3, qkv_bias=True, **kwargs)
This is the configuration class to store the configuration of a ViTModel. It is used to instantiate a ViT model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the ViT google/vit-base-patch16-224 architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
- Parameters
hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (str or function, optional, defaults to "gelu") – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
hidden_dropout_prob (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-12) – The epsilon used by the layer normalization layers.
image_size (int, optional, defaults to 224) – The size (resolution) of each image.
patch_size (int, optional, defaults to 16) – The size (resolution) of each patch.
num_channels (int, optional, defaults to 3) – The number of input channels.
qkv_bias (bool, optional, defaults to True) – Whether to add a bias to the queries, keys and values.
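The default values above imply a few derived quantities. The sketch below works them out in plain Python; derived_dims is a hypothetical helper (not a library function), and the requirement that hidden_size be divisible by num_attention_heads is the usual multi-head attention convention, assumed here rather than stated on this page.

```python
def derived_dims(hidden_size=768, num_attention_heads=12,
                 image_size=224, patch_size=16, num_channels=3):
    """Quantities implied by a ViT-style configuration (illustrative only)."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size should be divisible by num_attention_heads")
    return {
        # width of each attention head
        "head_dim": hidden_size // num_attention_heads,
        # number of non-overlapping patches per image
        "num_patches": (image_size // patch_size) ** 2,
        # length of one flattened patch before the linear embedding
        "patch_dim": num_channels * patch_size * patch_size,
    }


print(derived_dims())
```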
Example:
>>> from transformers import ViTModel, ViTConfig

>>> # Initializing a ViT vit-base-patch16-224 style configuration
>>> configuration = ViTConfig()

>>> # Initializing a model from the vit-base-patch16-224 style configuration
>>> model = ViTModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
ViTFeatureExtractor
class transformers.ViTFeatureExtractor(do_resize=True, size=224, resample=2, do_normalize=True, image_mean=None, image_std=None, **kwargs)
Constructs a ViT feature extractor.
This feature extractor inherits from FeatureExtractionMixin which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.
- Parameters
do_resize (bool, optional, defaults to True) – Whether to resize the input to a certain size.
size (int or Tuple(int), optional, defaults to 224) – Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an integer is provided, then the input will be resized to (size, size). Only has an effect if do_resize is set to True.
resample (int, optional, defaults to PIL.Image.BILINEAR) – An optional resampling filter. This can be one of PIL.Image.NEAREST, PIL.Image.BOX, PIL.Image.BILINEAR, PIL.Image.HAMMING, PIL.Image.BICUBIC or PIL.Image.LANCZOS. Only has an effect if do_resize is set to True.
do_normalize (bool, optional, defaults to True) – Whether or not to normalize the input with mean and standard deviation.
image_mean (List[float], defaults to [0.5, 0.5, 0.5]) – The sequence of means for each channel, to be used when normalizing images.
image_std (List[float], defaults to [0.5, 0.5, 0.5]) – The sequence of standard deviations for each channel, to be used when normalizing images.
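To see what the default mean and standard deviation of 0.5 do to a pixel, the sketch below normalizes a single 8-bit value by hand. It assumes pixel values are first scaled from 0..255 to [0, 1] before the mean is subtracted; that rescaling step and the helper normalize_pixel are assumptions for illustration, not a description of the extractor's internals.

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Map an 8-bit pixel value (0..255) to a normalized float."""
    scaled = value / 255.0  # assumed rescaling to [0, 1]
    return (scaled - mean) / std


print(normalize_pixel(0))    # -1.0
print(normalize_pixel(255))  # 1.0
```

Under these assumptions the defaults map the full 8-bit range onto [-1, 1].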
__call__(images: Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, List[PIL.Image.Image], List[numpy.ndarray], List[torch.Tensor]], return_tensors: Optional[Union[str, transformers.file_utils.TensorType]] = None, **kwargs) → transformers.feature_extraction_utils.BatchFeature
Main method to prepare one or several image(s) for the model.
Warning
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so passing PIL images is the most efficient option.
- Parameters
images (PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor]) – The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the number of channels, and H and W are the image height and width.
return_tensors (str or TensorType, optional, defaults to 'np') – If set, will return tensors of a particular framework. Acceptable values are:
'tf': Return TensorFlow tf.constant objects.
'pt': Return PyTorch torch.Tensor objects.
'np': Return NumPy np.ndarray objects.
'jax': Return JAX jnp.ndarray objects.
- Returns
A BatchFeature with the following fields:
pixel_values – Pixel values to be fed to a model, of shape (batch_size, num_channels, height, width).
- Return type
BatchFeature
ViTModel
class transformers.ViTModel(config, add_pooling_layer=True)
The bare ViT Model transformer outputting raw hidden-states without any specific head on top.
This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
- Parameters
config (ViTConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
forward(pixel_values=None, head_mask=None, output_attentions=None, output_hidden_states=None, interpolate_pos_encoding=None, return_dict=None)
The ViTModel forward method, overrides the __call__() special method.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
- Parameters
pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Pixel values can be obtained using ViTFeatureExtractor. See transformers.ViTFeatureExtractor.__call__() for details.
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) – Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
1 indicates the head is not masked,
0 indicates the head is masked.
output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
interpolate_pos_encoding (bool, optional) – Whether to interpolate the pre-trained position encodings.
return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.
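The head_mask convention above (1 keeps a head, 0 masks it) can be sketched with nested lists. The helper build_head_mask is hypothetical; in practice the result would be converted to a torch.FloatTensor of shape (num_layers, num_heads) before being passed as head_mask.

```python
def build_head_mask(num_layers, num_heads, heads_to_mask):
    """Return a (num_layers x num_heads) mask: 1.0 keeps a head, 0.0 masks it.

    heads_to_mask is an iterable of (layer_index, head_index) pairs.
    """
    mask = [[1.0] * num_heads for _ in range(num_layers)]
    for layer, head in heads_to_mask:
        mask[layer][head] = 0.0
    return mask


# Mask head 3 in the first layer and head 0 in the last layer of a
# 12-layer, 12-head model (the ViTConfig defaults).
mask = build_head_mask(12, 12, [(0, 3), (11, 0)])
print(mask[0])
```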
- Returns
A BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ViTConfig) and inputs.
last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.
pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)) – Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns the classification token after processing through a linear layer and a tanh activation function. The linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples:
>>> from transformers import ViTFeatureExtractor, ViTModel
>>> from PIL import Image
>>> import requests

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
>>> model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
- Return type
BaseModelOutputWithPooling or tuple(torch.FloatTensor)
ViTForImageClassification
class transformers.ViTForImageClassification(config)
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token) e.g. for ImageNet.
This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
- Parameters
config (ViTConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
forward(pixel_values=None, head_mask=None, labels=None, output_attentions=None, output_hidden_states=None, interpolate_pos_encoding=None, return_dict=None)
The ViTForImageClassification forward method, overrides the __call__() special method.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
- Parameters
pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Pixel values can be obtained using ViTFeatureExtractor. See transformers.ViTFeatureExtractor.__call__() for details.
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) – Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
1 indicates the head is not masked,
0 indicates the head is masked.
output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
interpolate_pos_encoding (bool, optional) – Whether to interpolate the pre-trained position encodings.
return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.
labels (torch.LongTensor of shape (batch_size,), optional) – Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (Mean-Square loss). If config.num_labels > 1, a classification loss is computed (Cross-Entropy).
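The rule for labels can be sketched in plain Python as follows. This is an illustration of the stated rule only, not the library's code: the helper name is hypothetical, and averaging over the batch is an assumption.

```python
import math


def classification_or_regression_loss(logits, labels, num_labels):
    """logits: one list of scores per example; labels: one target per example."""
    if num_labels == 1:
        # Regression: mean-square loss between the single output and the target.
        return sum((row[0] - y) ** 2 for row, y in zip(logits, labels)) / len(labels)
    # Classification: cross-entropy, averaged over the batch (assumed reduction).
    total = 0.0
    for row, y in zip(logits, labels):
        peak = max(row)  # subtract the maximum for numerical stability
        log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_sum_exp - row[y]
    return total / len(labels)


print(classification_or_regression_loss([[2.0]], [1.0], num_labels=1))
print(classification_or_regression_loss([[0.0, 0.0]], [0], num_labels=2))
```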
- Returns
A SequenceClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ViTConfig) and inputs.
loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Classification (or regression if config.num_labels==1) loss.
logits (torch.FloatTensor of shape (batch_size, config.num_labels)) – Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples:
>>> from transformers import ViTFeatureExtractor, ViTForImageClassification
>>> from PIL import Image
>>> import requests

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
>>> model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
- Return type
SequenceClassifierOutput or tuple(torch.FloatTensor)
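Because the returned logits are scores before SoftMax, they need one more step to be read as probabilities. The sketch below applies a softmax to a list of made-up scores and picks the top entries; softmax and top_k are hypothetical helpers written in plain Python, and the numbers are illustrative rather than real model output.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most probable (class_index, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


example_logits = [2.0, 0.5, -1.0, 3.0]  # made-up scores for four classes
for index, prob in top_k(example_logits, k=2):
    print(index, round(prob, 3))
```

With a real model, each class index would then be looked up in model.config.id2label, as in the example above.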
FlaxViTModel
class transformers.FlaxViTModel(config: transformers.models.vit.configuration_vit.ViTConfig, input_shape=None, seed: int = 0, dtype: numpy.dtype = <class 'jax._src.numpy.lax_numpy.float32'>, **kwargs)
The bare ViT Model transformer outputting raw hidden-states without any specific head on top.
This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
This model is also a Flax Linen flax.linen.Module subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization and parallelization.
- Parameters
config (ViTConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
__call__(pixel_values, params: dict = None, dropout_rng: jax._src.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None)
The FlaxViTPreTrainedModel forward method, overrides the __call__() special method.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
- Returns
A FlaxBaseModelOutputWithPooling or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ViTConfig) and inputs.
last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.
pooler_output (jnp.ndarray of shape (batch_size, hidden_size)) – Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Return type
FlaxBaseModelOutputWithPooling or tuple(jnp.ndarray)
Examples:
>>> from transformers import ViTFeatureExtractor, FlaxViTModel
>>> from PIL import Image
>>> import requests

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
>>> model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

>>> inputs = feature_extractor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
FlaxViTForImageClassification
class transformers.FlaxViTForImageClassification(config: transformers.models.vit.configuration_vit.ViTConfig, input_shape=None, seed: int = 0, dtype: numpy.dtype = <class 'jax._src.numpy.lax_numpy.float32'>, **kwargs)
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token) e.g. for ImageNet.
This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
This model is also a Flax Linen flax.linen.Module subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization and parallelization.
- Parameters
config (ViTConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
__call__(pixel_values, params: dict = None, dropout_rng: jax._src.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None)
The FlaxViTPreTrainedModel forward method, overrides the __call__() special method.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.
- Returns
A FlaxSequenceClassifierOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ViTConfig) and inputs.
logits (jnp.ndarray of shape (batch_size, config.num_labels)) – Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Return type
FlaxSequenceClassifierOutput or tuple(jnp.ndarray)
Example:
>>> from transformers import ViTFeatureExtractor, FlaxViTForImageClassification
>>> from PIL import Image
>>> import jax
>>> import requests

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
>>> model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

>>> inputs = feature_extractor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])