Swin Transformer

Overview

The Swin Transformer was proposed in Swin Transformer: Hierarchical Vision Transformer using Shifted Windows by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.

The abstract from the paper is the following:

This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text. To address these differences, we propose a hierarchical Transformer whose representation is computed with Shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation (53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and +2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones. The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures.

Swin Transformer architecture. Taken from the original paper.
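The linear-complexity claim in the abstract can be checked numerically. The sketch below uses the two cost formulas given in the Swin paper: 4hwC² + 2(hw)²C for global self-attention and 4hwC² + 2M²hwC for window attention, where h x w is the patch grid, C the channel dimension and M the window size. The function names are illustrative and are not part of the Transformers API.

```python
def msa_flops(h: int, w: int, c: int) -> int:
    """Cost of global multi-head self-attention on an h x w grid of C-dim tokens."""
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c


def wmsa_flops(h: int, w: int, c: int, m: int) -> int:
    """Cost of attention restricted to non-overlapping M x M windows."""
    return 4 * h * w * c**2 + 2 * m**2 * h * w * c


# Stage-1 values for the default configuration: 224 / 4 = 56 patches per side,
# embed_dim = 96, window_size = 7.
small = wmsa_flops(56, 56, 96, 7)
large = wmsa_flops(112, 112, 96, 7)  # image side doubled, so 4x as many patches

# Window attention grows in proportion to the number of patches.
print(large / small)
# Global attention grows faster because of the quadratic (hw)^2 term.
print(msa_flops(112, 112, 96) / msa_flops(56, 56, 96))
```

With a fixed window size, both terms of the window-attention cost are proportional to hw, which is why quadrupling the patch count quadruples the cost.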

This model was contributed by novice03. The TensorFlow version of this model was contributed by amyeroberts. The original code can be found here.

Usage tips

  • Swin pads its inputs, so it supports any input height and width (if divisible by 32).
  • Swin can be used as a backbone. When output_hidden_states=True, it outputs both hidden_states and reshaped_hidden_states. The reshaped_hidden_states have a shape of (batch_size, num_channels, height, width) rather than (batch_size, sequence_length, num_channels).
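To make the two layouts concrete, here is a minimal sketch of the per-stage feature-map sizes. It assumes the layout described in the Swin paper (stage i works on a grid downsampled by patch_size * 2**i and has embed_dim * 2**i channels). The helper name swin_stage_shapes is hypothetical, and the exact entries of the hidden_states tuple returned by the library may be ordered or sized differently, so treat this as shape arithmetic only.

```python
def swin_stage_shapes(image_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """Return (height, width, channels) of the feature map each stage works on.

    Assumes the paper's layout: the grid side halves and the channel count
    doubles from one stage to the next.
    """
    shapes = []
    for i in range(num_stages):
        side = image_size // (patch_size * 2**i)
        shapes.append((side, side, embed_dim * 2**i))
    return shapes


batch_size = 1
for height, width, channels in swin_stage_shapes():
    # Sequence layout: (batch_size, sequence_length, num_channels)
    sequence_shape = (batch_size, height * width, channels)
    # Spatial layout: (batch_size, num_channels, height, width)
    spatial_shape = (batch_size, channels, height, width)
    print(sequence_shape, spatial_shape)
```

The last stage is downsampled by 4 * 2**3 = 32 overall, which is where the divisibility by 32 comes from, and for a 224 x 224 input it yields a 7 x 7 grid with 768 channels.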

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Swin Transformer.

Image Classification

Besides that:

If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

SwinConfig

class transformers.SwinConfig

( image_size = 224 patch_size = 4 num_channels = 3 embed_dim = 96 depths = [2, 2, 6, 2] num_heads = [3, 6, 12, 24] window_size = 7 mlp_ratio = 4.0 qkv_bias = True hidden_dropout_prob = 0.0 attention_probs_dropout_prob = 0.0 drop_path_rate = 0.1 hidden_act = 'gelu' use_absolute_embeddings = False initializer_range = 0.02 layer_norm_eps = 1e-05 encoder_stride = 32 out_features = None out_indices = None **kwargs )

Parameters

  • image_size (int, optional, defaults to 224) — The size (resolution) of each image.
  • patch_size (int, optional, defaults to 4) — The size (resolution) of each patch.
  • num_channels (int, optional, defaults to 3) — The number of input channels.
  • embed_dim (int, optional, defaults to 96) — Dimensionality of patch embedding.
  • depths (list(int), optional, defaults to [2, 2, 6, 2]) — Depth of each layer in the Transformer encoder.
  • num_heads (list(int), optional, defaults to [3, 6, 12, 24]) — Number of attention heads in each layer of the Transformer encoder.
  • window_size (int, optional, defaults to 7) — Size of windows.
  • mlp_ratio (float, optional, defaults to 4.0) — Ratio of MLP hidden dimensionality to embedding dimensionality.
  • qkv_bias (bool, optional, defaults to True) — Whether or not a learnable bias should be added to the queries, keys and values.
  • hidden_dropout_prob (float, optional, defaults to 0.0) — The dropout probability for all fully connected layers in the embeddings and encoder.
  • attention_probs_dropout_prob (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
  • drop_path_rate (float, optional, defaults to 0.1) — Stochastic depth rate.
  • hidden_act (str or function, optional, defaults to "gelu") — The non-linear activation function (function or string) in the encoder. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
  • use_absolute_embeddings (bool, optional, defaults to False) — Whether or not to add absolute position embeddings to the patch embeddings.
  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • layer_norm_eps (float, optional, defaults to 1e-05) — The epsilon used by the layer normalization layers.
  • encoder_stride (int, optional, defaults to 32) — Factor to increase the spatial resolution by in the decoder head for masked image modeling.
  • out_features (List[str], optional) — If used as backbone, list of features to output. Can be any of "stem", "stage1", "stage2", etc. (depending on how many stages the model has). If unset and out_indices is set, will default to the corresponding stages. If unset and out_indices is unset, will default to the last stage. Must be in the same order as defined in the stage_names attribute.
  • out_indices (List[int], optional) — If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how many stages the model has). If unset and out_features is set, will default to the corresponding stages. If unset and out_features is unset, will default to the last stage. Must be in the same order as defined in the stage_names attribute.

This is the configuration class to store the configuration of a SwinModel. It is used to instantiate a Swin model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Swin microsoft/swin-tiny-patch4-window7-224 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
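As a rough illustration of how the default values fit together, the sketch below derives per-stage quantities from embed_dim, depths, num_heads and mlp_ratio. It assumes the channel count doubles at every stage, as in the paper. The helper stage_summary is hypothetical and does not exist in the library.

```python
def stage_summary(embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), mlp_ratio=4.0):
    """Derive per-stage channel count, head dimension and MLP width."""
    if len(depths) != len(num_heads):
        raise ValueError("depths and num_heads need one entry per stage")
    rows = []
    for stage, (depth, heads) in enumerate(zip(depths, num_heads)):
        channels = embed_dim * 2**stage  # assumption: channels double per stage
        if channels % heads:
            raise ValueError("channel count must be divisible by the number of heads")
        rows.append(
            {
                "stage": stage + 1,
                "blocks": depth,
                "channels": channels,
                "head_dim": channels // heads,
                "mlp_hidden": int(channels * mlp_ratio),
            }
        )
    return rows


for row in stage_summary():
    print(row)
```

With the defaults, the number of heads doubles together with the channel count, so every head keeps a dimension of 32.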

Example:

>>> from transformers import SwinConfig, SwinModel

>>> # Initializing a Swin microsoft/swin-tiny-patch4-window7-224 style configuration
>>> configuration = SwinConfig()

>>> # Initializing a model (with random weights) from the microsoft/swin-tiny-patch4-window7-224 style configuration
>>> model = SwinModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

SwinModel

class transformers.SwinModel

( config add_pooling_layer = True use_mask_token = False )

Parameters

  • config (SwinConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
  • add_pooling_layer (bool, optional, defaults to True) — Whether or not to apply pooling layer.
  • use_mask_token (bool, optional, defaults to False) — Whether or not to create and apply mask tokens in the embedding layer.

The bare Swin Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: Optional = None bool_masked_pos: Optional = None head_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.models.swin.modeling_swin.SwinModelOutput or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • bool_masked_pos (torch.BoolTensor of shape (batch_size, num_patches), optional) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
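The head_mask convention (1 keeps a head, 0 nullifies it) can be illustrated with plain Python lists. This sketch assumes the mask is applied multiplicatively to the attention probabilities of each head. The function apply_head_mask and the toy values are illustrative only.

```python
def apply_head_mask(attention_probs, head_mask):
    """Scale each head's attention probabilities by its mask value.

    attention_probs: list with one matrix (list of rows) per attention head.
    head_mask: one float per head, 1.0 keeps the head and 0.0 nullifies it.
    """
    if len(attention_probs) != len(head_mask):
        raise ValueError("head_mask needs one entry per attention head")
    return [
        [[mask * prob for prob in row] for row in head]
        for head, mask in zip(attention_probs, head_mask)
    ]


# Toy attention probabilities for two heads over two tokens.
attention_probs = [
    [[0.5, 0.5], [0.25, 0.75]],  # head 0
    [[0.9, 0.1], [0.4, 0.6]],    # head 1
]
masked = apply_head_mask(attention_probs, [1.0, 0.0])
print(masked[0])  # head 0 is unchanged
print(masked[1])  # head 1 contributes nothing
```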

Returns

transformers.models.swin.modeling_swin.SwinModelOutput or tuple(torch.FloatTensor)

A transformers.models.swin.modeling_swin.SwinModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SwinConfig) and inputs.

  • last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (torch.FloatTensor of shape (batch_size, hidden_size), optional, returned when add_pooling_layer=True is passed) — Average pooling of the last layer hidden-state.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • reshaped_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.

The SwinModel forward method overrides the __call__ special method.

Although the forward pass is defined within this function, call the Module instance rather than this method directly: calling the instance runs the pre- and post-processing steps, whereas calling forward directly silently skips them.

Example:

>>> from transformers import AutoImageProcessor, SwinModel
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

>>> inputs = image_processor(image, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 49, 768]

SwinForMaskedImageModeling

class transformers.SwinForMaskedImageModeling

( config )

Parameters

  • config (SwinConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

Swin Model with a decoder on top for masked image modeling, as proposed in SimMIM.

Note that we provide a script to pre-train this model on custom data in our examples directory.

This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: Optional = None bool_masked_pos: Optional = None head_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • bool_masked_pos (torch.BoolTensor of shape (batch_size, num_patches)) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
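Since bool_masked_pos holds one flag per patch, it can help to see how such a flat mask relates to image pixels. The following is a minimal sketch, assuming patches are numbered row by row and that each flag covers a patch_size x patch_size block of pixels. Both helper functions, the mask ratio and the seed are hypothetical choices made for illustration.

```python
import random


def random_patch_mask(num_patches, mask_ratio=0.5, seed=0):
    """Return a flat list of booleans with a fixed share of True (masked) entries."""
    rng = random.Random(seed)
    num_masked = int(num_patches * mask_ratio)
    flags = [True] * num_masked + [False] * (num_patches - num_masked)
    rng.shuffle(flags)
    return flags


def patch_mask_to_pixel_mask(patch_mask, grid_size, patch_size):
    """Expand a flat per-patch mask into a square per-pixel mask (row-major patches)."""
    if len(patch_mask) != grid_size * grid_size:
        raise ValueError("patch_mask must have grid_size * grid_size entries")
    pixel_mask = []
    for row in range(grid_size):
        patch_row = patch_mask[row * grid_size:(row + 1) * grid_size]
        # Repeat every flag patch_size times horizontally ...
        pixel_row = [flag for flag in patch_row for _ in range(patch_size)]
        # ... and the resulting row patch_size times vertically.
        pixel_mask.extend(list(pixel_row) for _ in range(patch_size))
    return pixel_mask


# Assumed values: a 192 x 192 image split into 4 x 4 patches.
image_size, patch_size = 192, 4
grid_size = image_size // patch_size
patch_mask = random_patch_mask(grid_size**2)
pixel_mask = patch_mask_to_pixel_mask(patch_mask, grid_size, patch_size)
print(len(patch_mask), len(pixel_mask), len(pixel_mask[0]))
```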

Returns

transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput or tuple(torch.FloatTensor)

A transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SwinConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when bool_masked_pos is provided) — Masked image modeling (MIM) loss.

  • reconstruction (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Reconstructed pixel values.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • reshaped_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.

The SwinForMaskedImageModeling forward method overrides the __call__ special method.

Although the forward pass is defined within this function, call the Module instance rather than this method directly: calling the instance runs the pre- and post-processing steps, whereas calling forward directly silently skips them.

Examples:

>>> from transformers import AutoImageProcessor, SwinForMaskedImageModeling
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
>>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")

>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
>>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
>>> # create random boolean mask of shape (batch_size, num_patches)
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 192, 192]
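The reconstruction shape printed above is consistent with the encoder_stride parameter of SwinConfig. The sketch below traces the shapes under the assumption that the decoder head is a 1x1 convolution producing encoder_stride**2 * num_channels channels followed by a pixel shuffle, as in SimMIM. This is an assumption about the head, not a statement of the library's implementation, and reconstruction_shape is a hypothetical helper.

```python
def reconstruction_shape(image_size, num_channels, encoder_stride):
    """Trace (channels, height, width) through an assumed conv + pixel-shuffle head."""
    if image_size % encoder_stride:
        raise ValueError("image_size must be divisible by encoder_stride")
    side = image_size // encoder_stride  # grid side after the encoder
    conv_channels = encoder_stride**2 * num_channels
    after_conv = (conv_channels, side, side)
    # A pixel shuffle with upscale factor r maps (C * r**2, H, W) to (C, H * r, W * r).
    after_shuffle = (
        conv_channels // encoder_stride**2,
        side * encoder_stride,
        side * encoder_stride,
    )
    return after_conv, after_shuffle


after_conv, after_shuffle = reconstruction_shape(image_size=192, num_channels=3, encoder_stride=32)
print(after_conv)
print(after_shuffle)
```

Under these assumptions a 192 x 192 input is encoded to a 6 x 6 grid and upsampled back by a factor of 32 to the original resolution.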

SwinForImageClassification

class transformers.SwinForImageClassification

( config )

Parameters

  • config (SwinConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

Swin Model transformer with an image classification head on top (a linear layer on top of the pooled output of the final stage), e.g. for ImageNet.

This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: Optional = None head_mask: Optional = None labels: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.models.swin.modeling_swin.SwinImageClassifierOutput or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • labels (torch.LongTensor of shape (batch_size,), optional) — Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (Mean-Square loss). If config.num_labels > 1, a classification loss is computed (Cross-Entropy).
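The rule for choosing the loss from config.num_labels can be sketched for a single example in plain Python. The function names are illustrative, and the library computes these losses on batched tensors, so this only mirrors the rule stated for the labels parameter.

```python
import math


def cross_entropy(logits, label):
    """Negative log-probability of the true class, using a stable log-sum-exp."""
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return log_sum_exp - logits[label]


def mean_squared_error(prediction, target):
    """Squared error for a single scalar prediction."""
    return (prediction - target) ** 2


def classification_or_regression_loss(logits, label, num_labels):
    """Pick the loss from num_labels: 1 means regression, more means classification."""
    if num_labels == 1:
        return mean_squared_error(logits[0], float(label))
    return cross_entropy(logits, label)


print(classification_or_regression_loss([2.0], 1, num_labels=1))       # regression
print(classification_or_regression_loss([0.0, 0.0], 0, num_labels=2))  # classification
```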

Returns

transformers.models.swin.modeling_swin.SwinImageClassifierOutput or tuple(torch.FloatTensor)

A transformers.models.swin.modeling_swin.SwinImageClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SwinConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.

  • logits (torch.FloatTensor of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • reshaped_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.

The SwinForImageClassification forward method overrides the __call__ special method.

Although the forward pass is defined within this function, call the Module instance rather than this method directly: calling the instance runs the pre- and post-processing steps, whereas calling forward directly silently skips them.

Example:

>>> from transformers import AutoImageProcessor, SwinForImageClassification
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

>>> inputs = image_processor(image, return_tensors="pt")

>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
tabby, tabby cat
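Because the logits are scores before the softmax, turning them into probabilities and ranking the classes takes one more step. The sketch below does this in plain Python on made-up values. The toy logits and the four-entry id2label mapping are invented for illustration and are not model outputs.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities, shifting by the max for stability."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Toy values standing in for model.config.id2label and one row of logits.
toy_id2label = {0: "tabby, tabby cat", 1: "tiger cat", 2: "Egyptian cat", 3: "remote control, remote"}
toy_logits = [4.0, 2.5, 2.0, -1.0]

for label, prob in top_k(toy_logits, toy_id2label):
    print(f"{label}: {prob:.3f}")
```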

TFSwinModel

class transformers.TFSwinModel

( config: SwinConfig add_pooling_layer: bool = True use_mask_token: bool = False **kwargs )

Parameters

  • config (SwinConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The bare Swin Model transformer outputting raw hidden-states without any specific head on top. This model is a TensorFlow keras.layers.Layer sub-class. Use it as a regular TensorFlow Module and refer to the TensorFlow documentation for all matters related to general usage and behavior.

call

( pixel_values: tf.Tensor | None = None bool_masked_pos: tf.Tensor | None = None head_mask: tf.Tensor | None = None output_attentions: Optional[bool] = None output_hidden_states: Optional[bool] = None return_dict: Optional[bool] = None training: bool = False ) → transformers.models.swin.modeling_tf_swin.TFSwinModelOutput or tuple(tf.Tensor)

Parameters

  • pixel_values (tf.Tensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
  • head_mask (tf.Tensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • bool_masked_pos (tf.Tensor of shape (batch_size, num_patches), optional) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

Returns

transformers.models.swin.modeling_tf_swin.TFSwinModelOutput or tuple(tf.Tensor)

A transformers.models.swin.modeling_tf_swin.TFSwinModelOutput or a tuple of tf.Tensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SwinConfig) and inputs.

  • last_hidden_state (tf.Tensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (tf.Tensor of shape (batch_size, hidden_size), optional, returned when add_pooling_layer=True is passed) — Average pooling of the last layer hidden-state.

  • hidden_states (tuple(tf.Tensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of tf.Tensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(tf.Tensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of tf.Tensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • reshaped_hidden_states (tuple(tf.Tensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of tf.Tensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.

The TFSwinModel forward method overrides the __call__ special method.

Although the forward pass is defined within this function, call the Module instance rather than this method directly: calling the instance runs the pre- and post-processing steps, whereas calling the method directly silently skips them.

Example:

>>> from transformers import AutoImageProcessor, TFSwinModel
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = TFSwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

>>> inputs = image_processor(image, return_tensors="tf")
>>> outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 49, 768]

TFSwinForMaskedImageModeling

class transformers.TFSwinForMaskedImageModeling

( config: SwinConfig )

Parameters

  • config (SwinConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

Swin Model with a decoder on top for masked image modeling, as proposed in SimMIM. This model is a TensorFlow keras.layers.Layer sub-class. Use it as a regular TensorFlow Module and refer to the TensorFlow documentation for all matters related to general usage and behavior.

call

( pixel_values: tf.Tensor | None = None bool_masked_pos: tf.Tensor | None = None head_mask: tf.Tensor | None = None output_attentions: Optional[bool] = None output_hidden_states: Optional[bool] = None return_dict: Optional[bool] = None training: bool = False ) → transformers.models.swin.modeling_tf_swin.TFSwinMaskedImageModelingOutput or tuple(tf.Tensor)

Parameters

  • pixel_values (tf.Tensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
  • head_mask (tf.Tensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • bool_masked_pos (tf.Tensor of shape (batch_size, num_patches)) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

Returns

transformers.models.swin.modeling_tf_swin.TFSwinMaskedImageModelingOutput or tuple(tf.Tensor)

A transformers.models.swin.modeling_tf_swin.TFSwinMaskedImageModelingOutput or a tuple of tf.Tensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SwinConfig) and inputs.

  • loss (tf.Tensor of shape (1,), optional, returned when bool_masked_pos is provided) — Masked image modeling (MIM) loss.

  • reconstruction (tf.Tensor of shape (batch_size, num_channels, height, width)) — Reconstructed pixel values.

  • hidden_states (tuple(tf.Tensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of tf.Tensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(tf.Tensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of tf.Tensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • reshaped_hidden_states (tuple(tf.Tensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of tf.Tensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.

The TFSwinForMaskedImageModeling forward method overrides the __call__ special method.

Although the forward pass is defined within this function, call the Module instance rather than this method directly: calling the instance runs the pre- and post-processing steps, whereas calling the method directly silently skips them.

Examples:

>>> from transformers import AutoImageProcessor, TFSwinForMaskedImageModeling
>>> import tensorflow as tf
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = TFSwinForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
>>> pixel_values = image_processor(images=image, return_tensors="tf").pixel_values
>>> # create random boolean mask of shape (batch_size, num_patches)
>>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5

>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
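To show how a per-patch boolean mask such as bool_masked_pos can restrict a reconstruction loss to masked regions, here is a minimal NumPy sketch. It assumes an L1 reconstruction error averaged over masked pixels and channels. The helper name masked_l1_loss is hypothetical and is not part of the library, and the model's own loss computation may differ in detail.

```python
import numpy as np

def masked_l1_loss(pixel_values, reconstruction, bool_masked_pos, patch_size):
    """Mean absolute error over masked pixels only (illustrative helper, not a library function)."""
    batch_size, num_channels, height, width = pixel_values.shape
    grid_h, grid_w = height // patch_size, width // patch_size

    # Expand the per-patch mask (batch_size, num_patches) to a per-pixel mask.
    patch_mask = bool_masked_pos.reshape(batch_size, grid_h, grid_w).astype(np.float64)
    pixel_mask = patch_mask.repeat(patch_size, axis=1).repeat(patch_size, axis=2)
    pixel_mask = pixel_mask[:, None, :, :]  # broadcast over the channel axis

    abs_error = np.abs(pixel_values - reconstruction)
    # The small epsilon guards against division by zero when nothing is masked.
    return (abs_error * pixel_mask).sum() / (pixel_mask.sum() + 1e-5) / num_channels

# Toy input: a 4x4 image split into four 2x2 patches, two of them masked.
pixel_values = np.zeros((1, 3, 4, 4))
reconstruction = np.ones((1, 3, 4, 4))
bool_masked_pos = np.array([[True, False, False, True]])

loss = masked_l1_loss(pixel_values, reconstruction, bool_masked_pos, patch_size=2)
print(round(float(loss), 4))  # 1.0
```

Only the two masked patches contribute to the result. Errors on unmasked patches are multiplied by zero before the sum.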

TFSwinForImageClassification

class transformers.TFSwinForImageClassification

( config: SwinConfig )

Parameters

  • config (SwinConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

Swin Model transformer with an image classification head on top (a linear layer on top of the average-pooled final hidden states), e.g. for ImageNet.

This model is a TensorFlow keras.layers.Layer sub-class. Use it as a regular TensorFlow Module and refer to the TensorFlow documentation for all matters related to general usage and behavior.

call

( pixel_values: tf.Tensor | None = None head_mask: tf.Tensor | None = None labels: tf.Tensor | None = None output_attentions: Optional[bool] = None output_hidden_states: Optional[bool] = None return_dict: Optional[bool] = None training: bool = False ) → transformers.models.swin.modeling_tf_swin.TFSwinImageClassifierOutput or tuple(tf.Tensor)

Parameters

  • pixel_values (tf.Tensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
  • head_mask (tf.Tensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • labels (tf.Tensor of shape (batch_size,), optional) — Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (mean squared error). If config.num_labels > 1, a classification loss is computed (cross-entropy).
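The way the labels parameter selects between the two losses can be sketched in NumPy. This is a minimal illustration of the rule stated above, not the library's implementation. The function name classification_or_regression_loss is hypothetical.

```python
import numpy as np

def classification_or_regression_loss(logits, labels, num_labels):
    """Pick the loss from num_labels (illustrative helper, not a library function)."""
    if num_labels == 1:
        # Regression: mean squared error between the single score and the target.
        return float(np.mean((logits.reshape(-1) - labels.astype(np.float64)) ** 2))
    # Classification: cross-entropy computed from a numerically stable log-softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

# Regression: one score per image, float targets.
regression_loss = classification_or_regression_loss(
    np.array([[2.0], [0.0]]), np.array([1.0, 1.0]), num_labels=1
)
print(regression_loss)  # 1.0

# Classification: uniform logits over 4 classes give a loss of log(4).
classification_loss = classification_or_regression_loss(
    np.zeros((2, 4)), np.array([0, 3]), num_labels=4
)
print(round(classification_loss, 4))  # 1.3863
```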

Returns

transformers.models.swin.modeling_tf_swin.TFSwinImageClassifierOutput or tuple(tf.Tensor)

A transformers.models.swin.modeling_tf_swin.TFSwinImageClassifierOutput or a tuple of tf.Tensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SwinConfig) and inputs.

  • loss (tf.Tensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.

  • logits (tf.Tensor of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).

  • hidden_states (tuple(tf.Tensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of tf.Tensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(tf.Tensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of tf.Tensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • reshaped_hidden_states (tuple(tf.Tensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of tf.Tensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.

The TFSwinForImageClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, call the model instance rather than this method directly: calling the instance runs the pre- and post-processing steps, while calling the method silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, TFSwinForImageClassification
>>> import tensorflow as tf
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = TFSwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

>>> inputs = image_processor(image, return_tensors="tf")
>>> logits = model(**inputs).logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = int(tf.math.argmax(logits, axis=-1))
>>> print(model.config.id2label[predicted_label])
tabby, tabby cat
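Because the logits are scores before SoftMax, turning them into probabilities and listing the top few classes is a common follow-up. The sketch below uses NumPy with made-up logits and a hypothetical id2label mapping. The helper top_k_predictions is illustrative and not part of the library.

```python
import numpy as np

def top_k_predictions(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs (illustrative helper)."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    top = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(id2label[int(i)], float(probs[i])) for i in top]

# Made-up logits and labels for a single image; not real model output.
logits = np.array([0.5, 2.0, -1.0, 1.0])
id2label = {0: "class_a", 1: "class_b", 2: "class_c", 3: "class_d"}

predictions = top_k_predictions(logits, id2label, k=2)
print([label for label, _ in predictions])  # ['class_b', 'class_d']
```

With a real model, the one-dimensional logits for a single image could be obtained by converting one row of the logits tensor to a NumPy array, and id2label would come from model.config.id2label.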