SegFormer

Overview

The SegFormer model was proposed in SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. The model consists of a hierarchical Transformer encoder and a lightweight all-MLP decode head to achieve great results on image segmentation benchmarks such as ADE20K and Cityscapes.

The abstract from the paper is the following:

We present SegFormer, a simple, efficient yet powerful semantic segmentation framework which unifies Transformers with lightweight multilayer perception (MLP) decoders. SegFormer has two appealing features: 1) SegFormer comprises a novel hierarchically structured Transformer encoder which outputs multiscale features. It does not need positional encoding, thereby avoiding the interpolation of positional codes which leads to decreased performance when the testing resolution differs from training. 2) SegFormer avoids complex decoders. The proposed MLP decoder aggregates information from different layers, and thus combining both local attention and global attention to render powerful representations. We show that this simple and lightweight design is the key to efficient segmentation on Transformers. We scale our approach up to obtain a series of models from SegFormer-B0 to SegFormer-B5, reaching significantly better performance and efficiency than previous counterparts. For example, SegFormer-B4 achieves 50.3% mIoU on ADE20K with 64M parameters, being 5x smaller and 2.2% better than the previous best method. Our best model, SegFormer-B5, achieves 84.0% mIoU on Cityscapes validation set and shows excellent zero-shot robustness on Cityscapes-C.

This model was contributed by nielsr. The original code can be found here.

SegformerConfig

class transformers.SegformerConfig(image_size=224, num_channels=3, num_encoder_blocks=4, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], hidden_sizes=[32, 64, 160, 256], downsampling_rates=[1, 4, 8, 16], patch_sizes=[7, 3, 3, 3], strides=[4, 2, 2, 2], num_attention_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], hidden_act='gelu', hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, classifier_dropout_prob=0.1, initializer_range=0.02, drop_path_rate=0.1, layer_norm_eps=1e-06, decoder_hidden_size=256, is_encoder_decoder=False, reshape_last_stage=True, **kwargs)[source]

This is the configuration class to store the configuration of a SegformerModel. It is used to instantiate a SegFormer model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the SegFormer nvidia/segformer-b0-finetuned-ade-512-512 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Parameters
  • image_size (int, optional, defaults to 224) – The size (resolution) of each image.

  • num_channels (int, optional, defaults to 3) – The number of input channels.

  • num_encoder_blocks (int, optional, defaults to 4) – The number of encoder blocks (i.e. stages in the Mix Transformer encoder).

  • depths (List[int], optional, defaults to [2, 2, 2, 2]) – The number of layers in each encoder block.

  • sr_ratios (List[int], optional, defaults to [8, 4, 2, 1]) – Sequence reduction ratios in each encoder block.

  • hidden_sizes (List[int], optional, defaults to [32, 64, 160, 256]) – Dimension of each of the encoder blocks.

  • downsampling_rates (List[int], optional, defaults to [1, 4, 8, 16]) – Downsample rate of the image resolution compared to the original image size before each encoder block.

  • patch_sizes (List[int], optional, defaults to [7, 3, 3, 3]) – Patch size before each encoder block.

  • strides (List[int], optional, defaults to [4, 2, 2, 2]) – Stride before each encoder block.

  • num_attention_heads (List[int], optional, defaults to [1, 2, 5, 8]) – Number of attention heads for each attention layer in each block of the Transformer encoder.

  • mlp_ratios (List[int], optional, defaults to [4, 4, 4, 4]) – Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the encoder blocks.

  • hidden_act (str or function, optional, defaults to "gelu") – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" are supported.

  • hidden_dropout_prob (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_probs_dropout_prob (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • classifier_dropout_prob (float, optional, defaults to 0.1) – The dropout probability before the classification head.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • drop_path_rate (float, optional, defaults to 0.1) – The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.

  • layer_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the layer normalization layers.

  • decoder_hidden_size (int, optional, defaults to 256) – The dimension of the all-MLP decode head.

  • reshape_last_stage (bool, optional, defaults to True) – Whether to reshape the features of the last stage back to (batch_size, num_channels, height, width). Only required for the semantic segmentation model.

Example:

>>> from transformers import SegformerModel, SegformerConfig

>>> # Initializing a SegFormer nvidia/segformer-b0-finetuned-ade-512-512 style configuration
>>> configuration = SegformerConfig()

>>> # Initializing a model from the nvidia/segformer-b0-finetuned-ade-512-512 style configuration
>>> model = SegformerModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
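
The defaults above describe the smallest (B0-sized) encoder. To experiment with a different capacity, the per-stage arguments can be overridden directly; the values below are purely illustrative and are not taken from any released SegFormer checkpoint:

>>> from transformers import SegformerConfig, SegformerModel

>>> # Illustrative only: a wider and deeper randomly initialized variant
>>> custom_config = SegformerConfig(
...     depths=[3, 4, 6, 3],
...     hidden_sizes=[64, 128, 320, 512],
...     num_attention_heads=[1, 2, 5, 8],
...     decoder_hidden_size=768,
... )
>>> model = SegformerModel(custom_config)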

SegformerFeatureExtractor

class transformers.SegformerFeatureExtractor(do_resize=True, keep_ratio=True, image_scale=(2048, 512), align=True, size_divisor=32, resample=2, do_random_crop=True, crop_size=(512, 512), do_normalize=True, image_mean=None, image_std=None, do_pad=True, padding_value=0, segmentation_padding_value=255, reduce_zero_label=False, **kwargs)[source]

Constructs a SegFormer feature extractor.

This feature extractor inherits from FeatureExtractionMixin which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

Parameters
  • do_resize (bool, optional, defaults to True) – Whether to resize/rescale the input based on a certain image_scale.

  • keep_ratio (bool, optional, defaults to True) – Whether to keep the aspect ratio when resizing the input. Only has an effect if do_resize is set to True.

  • image_scale (float or int or Tuple[int]/List[int], optional, defaults to (2048, 512)) –

    In case keep_ratio is set to True, the scaling factor or maximum size. If it is a float number, then the image will be rescaled by this factor, else if it is a tuple/list of 2 integers (width, height), then the image will be rescaled as large as possible within the scale. In case keep_ratio is set to False, the target size (width, height) to which the image will be resized. If only an integer is provided, then the input will be resized to (size, size).

    Only has an effect if do_resize is set to True.

  • align (bool, optional, defaults to True) – Whether to ensure the long and short sides are divisible by size_divisor. Only has an effect if do_resize and keep_ratio are set to True.

  • size_divisor (int, optional, defaults to 32) – The integer by which both sides of an image should be divisible. Only has an effect if do_resize and align are set to True.

  • resample (int, optional, defaults to PIL.Image.BILINEAR) – An optional resampling filter. This can be one of PIL.Image.NEAREST, PIL.Image.BOX, PIL.Image.BILINEAR, PIL.Image.HAMMING, PIL.Image.BICUBIC or PIL.Image.LANCZOS. Only has an effect if do_resize is set to True.

  • do_random_crop (bool, optional, defaults to True) – Whether or not to randomly crop the input to a certain crop_size.

  • crop_size (Tuple[int]/List[int], optional, defaults to (512, 512)) – The crop size to use, as a tuple (width, height). Only has an effect if do_random_crop is set to True.

  • do_normalize (bool, optional, defaults to True) – Whether or not to normalize the input with mean and standard deviation.

  • image_mean (List[float], optional, defaults to [0.485, 0.456, 0.406]) – The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.

  • image_std (List[float], optional, defaults to [0.229, 0.224, 0.225]) – The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the ImageNet std.

  • do_pad (bool, optional, defaults to True) – Whether or not to pad the input to crop_size. Note that padding should only be applied in combination with random cropping.

  • padding_value (int, optional, defaults to 0) – Fill value for padding images.

  • segmentation_padding_value (int, optional, defaults to 255) – Fill value for padding segmentation maps. One must make sure the ignore_index of the CrossEntropyLoss is set equal to this value.

  • reduce_zero_label (bool, optional, defaults to False) – Whether or not to reduce all label values by 1. Usually used for datasets where 0 is the background label.
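
Random cropping and padding are intended for training-time preprocessing. For inference, one would typically disable them; a minimal sketch using the constructor arguments listed above:

>>> from transformers import SegformerFeatureExtractor

>>> # Sketch: evaluation-time preprocessing without random cropping or padding
>>> feature_extractor = SegformerFeatureExtractor(
...     do_resize=True,
...     keep_ratio=True,
...     image_scale=(2048, 512),
...     do_random_crop=False,
...     do_pad=False,
... )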

__call__(images: Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, List[PIL.Image.Image], List[numpy.ndarray], List[torch.Tensor]], segmentation_maps: Union[PIL.Image.Image, numpy.ndarray, List[PIL.Image.Image], List[numpy.ndarray]] = None, return_tensors: Optional[Union[str, transformers.file_utils.TensorType]] = None, **kwargs) → transformers.feature_extraction_utils.BatchFeature[source]

Main method to prepare for the model one or several image(s) and optional corresponding segmentation maps.

Warning

NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass PIL images.

Parameters
  • images (PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor]) – The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the number of channels, H and W are image height and width.

  • segmentation_maps (PIL.Image.Image, np.ndarray, List[PIL.Image.Image], List[np.ndarray], optional) – Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.

  • return_tensors (str or TensorType, optional) –

    If set, will return tensors of a particular framework. Acceptable values are:

    • 'tf': Return TensorFlow tf.constant objects.

    • 'pt': Return PyTorch torch.Tensor objects.

    • 'np': Return NumPy np.ndarray objects.

    • 'jax': Return JAX jnp.ndarray objects.

Returns

A BatchFeature with the following fields:

  • pixel_values – Pixel values to be fed to a model, of shape (batch_size, num_channels, height, width).

  • labels – Optional labels to be fed to a model (when segmentation_maps are provided)

Return type

BatchFeature
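
During fine-tuning, images and their annotation maps are typically passed together so that pixel_values and labels stay aligned through resizing and cropping. A minimal sketch, assuming a local image.png and a single-channel annotation.png (hypothetical file names):

>>> from transformers import SegformerFeatureExtractor
>>> from PIL import Image

>>> # reduce_zero_label=True if label 0 denotes the background class, as described above
>>> feature_extractor = SegformerFeatureExtractor(reduce_zero_label=True)

>>> image = Image.open("image.png")
>>> segmentation_map = Image.open("annotation.png")

>>> encoding = feature_extractor(images=image, segmentation_maps=segmentation_map, return_tensors="pt")
>>> encoding["pixel_values"].shape  # (1, 3, 512, 512) with the default crop_size
>>> encoding["labels"].shape  # (1, 512, 512)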

SegformerModel

class transformers.SegformerModel(config)[source]

The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

config (SegformerConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

forward(pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None)[source]

The SegformerModel forward method overrides the __call__() special method.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Parameters
  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using SegformerFeatureExtractor. See transformers.SegformerFeatureExtractor.__call__() for details.

  • output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.

  • output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.

  • return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.

Returns

A BaseModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SegformerConfig) and inputs.

  • last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

>>> from transformers import SegformerFeatureExtractor, SegformerModel
>>> from PIL import Image
>>> import requests

>>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/mit-b0")
>>> model = SegformerModel.from_pretrained("nvidia/mit-b0")

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> sequence_output = outputs.last_hidden_state

Return type

BaseModelOutput or tuple(torch.FloatTensor)
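
Because the encoder is hierarchical, the per-stage feature maps can be inspected by passing output_hidden_states=True. A short sketch continuing the example above:

>>> outputs = model(**inputs, output_hidden_states=True)
>>> # one feature map per encoder stage, at progressively lower spatial resolution
>>> for feature_map in outputs.hidden_states:
...     print(feature_map.shape)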

SegformerDecodeHead

class transformers.SegformerDecodeHead(config)[source]
forward(encoder_hidden_states)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

SegformerForImageClassification

class transformers.SegformerForImageClassification(config)[source]

SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden states) e.g. for ImageNet.

This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

config (SegformerConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

forward(pixel_values=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None)[source]

The SegformerForImageClassification forward method overrides the __call__() special method.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Parameters
  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using SegformerFeatureExtractor. See transformers.SegformerFeatureExtractor.__call__() for details.

  • output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.

  • output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.

  • return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.

  • labels (torch.LongTensor of shape (batch_size,), optional) – Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (Mean-Square loss). If config.num_labels > 1, a classification loss is computed (Cross-Entropy).

Returns

A SequenceClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SegformerConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Classification (or regression if config.num_labels==1) loss.

  • logits (torch.FloatTensor of shape (batch_size, config.num_labels)) – Classification (or regression if config.num_labels==1) scores (before SoftMax).

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

>>> from transformers import SegformerFeatureExtractor, SegformerForImageClassification
>>> from PIL import Image
>>> import requests

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = SegformerFeatureExtractor.from_pretrained('nvidia/mit-b0')
>>> model = SegformerForImageClassification.from_pretrained('nvidia/mit-b0')

>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])

Return type

SequenceClassifierOutput or tuple(torch.FloatTensor)

SegformerForSemanticSegmentation

class transformers.SegformerForSemanticSegmentation(config)[source]

SegFormer Model transformer with an all-MLP decode head on top, e.g. for ADE20K or Cityscapes. This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

config (SegformerConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

forward(pixel_values, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None)[source]

The SegformerForSemanticSegmentation forward method overrides the __call__() special method.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Parameters
  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using SegformerFeatureExtractor. See transformers.SegformerFeatureExtractor.__call__() for details.

  • output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.

  • output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.

  • return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.

  • labels (torch.LongTensor of shape (batch_size, height, width), optional) – Ground truth semantic segmentation maps for computing the loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels > 1, a classification loss is computed (Cross-Entropy).

Returns

A SequenceClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SegformerConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Classification (or regression if config.num_labels==1) loss.

  • logits (torch.FloatTensor of shape (batch_size, config.num_labels, height/4, width/4)) – Classification scores for each pixel (before SoftMax).

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

>>> from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
>>> from PIL import Image
>>> import requests

>>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
>>> model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)

Return type

SequenceClassifierOutput or tuple(torch.FloatTensor)
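
Since the logits are produced at 1/4 of the input resolution, obtaining a per-pixel prediction typically involves upsampling them back to the image size and taking an argmax over the class dimension. A minimal sketch continuing the example above:

>>> import torch

>>> # upsample logits to the original image size and take the per-pixel argmax
>>> upsampled_logits = torch.nn.functional.interpolate(
...     outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False
... )
>>> predicted_segmentation = upsampled_logits.argmax(dim=1)[0]  # shape (height, width)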