Transformers documentation

DPT

You are viewing v4.46.0 version. A newer version v4.46.3 is available.

Overview

The DPT model was proposed in Vision Transformers for Dense Prediction by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. DPT is a model that leverages the Vision Transformer (ViT) as backbone for dense prediction tasks like semantic segmentation and depth estimation.

The abstract from the paper is the following:

We introduce dense vision transformers, an architecture that leverages vision transformers in place of convolutional networks as a backbone for dense prediction tasks. We assemble tokens from various stages of the vision transformer into image-like representations at various resolutions and progressively combine them into full-resolution predictions using a convolutional decoder. The transformer backbone processes representations at a constant and relatively high resolution and has a global receptive field at every stage. These properties allow the dense vision transformer to provide finer-grained and more globally coherent predictions when compared to fully-convolutional networks. Our experiments show that this architecture yields substantial improvements on dense prediction tasks, especially when a large amount of training data is available. For monocular depth estimation, we observe an improvement of up to 28% in relative performance when compared to a state-of-the-art fully-convolutional network. When applied to semantic segmentation, dense vision transformers set a new state of the art on ADE20K with 49.02% mIoU. We further show that the architecture can be fine-tuned on smaller datasets such as NYUv2, KITTI, and Pascal Context where it also sets the new state of the art.

DPT architecture. Taken from the original paper.

This model was contributed by nielsr. The original code can be found here.

Usage tips

DPT is compatible with the AutoBackbone class. This allows using the DPT framework with various computer vision backbones available in the library, such as VitDetBackbone or Dinov2Backbone. One can create it as follows:

from transformers import Dinov2Config, DPTConfig, DPTForDepthEstimation

# initialize with a Transformer-based backbone such as DINOv2
# in that case, we also specify `reshape_hidden_states=False` to get feature maps of shape (batch_size, num_channels, height, width)
backbone_config = Dinov2Config.from_pretrained("facebook/dinov2-base", out_features=["stage1", "stage2", "stage3", "stage4"], reshape_hidden_states=False)

config = DPTConfig(backbone_config=backbone_config)
model = DPTForDepthEstimation(config=config)

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DPT.

If you’re interested in submitting a resource to be included here, please feel free to open a Pull Request and we’ll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

DPTConfig

class transformers.DPTConfig

( hidden_size = 768 num_hidden_layers = 12 num_attention_heads = 12 intermediate_size = 3072 hidden_act = 'gelu' hidden_dropout_prob = 0.0 attention_probs_dropout_prob = 0.0 initializer_range = 0.02 layer_norm_eps = 1e-12 image_size = 384 patch_size = 16 num_channels = 3 is_hybrid = False qkv_bias = True backbone_out_indices = [2, 5, 8, 11] readout_type = 'project' reassemble_factors = [4, 2, 1, 0.5] neck_hidden_sizes = [96, 192, 384, 768] fusion_hidden_size = 256 head_in_index = -1 use_batch_norm_in_fusion_residual = False use_bias_in_fusion_residual = None add_projection = False use_auxiliary_head = True auxiliary_loss_weight = 0.4 semantic_loss_ignore_index = 255 semantic_classifier_dropout = 0.1 backbone_featmap_shape = [1, 1024, 24, 24] neck_ignore_stages = [0, 1] backbone_config = None backbone = None use_pretrained_backbone = False use_timm_backbone = False backbone_kwargs = None **kwargs )

Parameters

  • hidden_size (int, optional, defaults to 768) — Dimensionality of the encoder layers and the pooler layer.
  • num_hidden_layers (int, optional, defaults to 12) — Number of hidden layers in the Transformer encoder.
  • num_attention_heads (int, optional, defaults to 12) — Number of attention heads for each attention layer in the Transformer encoder.
  • intermediate_size (int, optional, defaults to 3072) — Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
  • hidden_act (str or function, optional, defaults to "gelu") — The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
  • hidden_dropout_prob (float, optional, defaults to 0.0) — The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  • attention_probs_dropout_prob (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • layer_norm_eps (float, optional, defaults to 1e-12) — The epsilon used by the layer normalization layers.
  • image_size (int, optional, defaults to 384) — The size (resolution) of each image.
  • patch_size (int, optional, defaults to 16) — The size (resolution) of each patch.
  • num_channels (int, optional, defaults to 3) — The number of input channels.
  • is_hybrid (bool, optional, defaults to False) — Whether to use a hybrid backbone. Useful in the context of loading DPT-Hybrid models.
  • qkv_bias (bool, optional, defaults to True) — Whether to add a bias to the queries, keys and values.
  • backbone_out_indices (List[int], optional, defaults to [2, 5, 8, 11]) — Indices of the intermediate hidden states to use from backbone.
  • readout_type (str, optional, defaults to "project") — The readout type to use when processing the readout token (CLS token) of the intermediate hidden states of the ViT backbone. Can be one of ["ignore", "add", "project"].

    • “ignore” simply ignores the CLS token.
    • “add” passes the information from the CLS token to all other tokens by adding the representations.
    • “project” passes information to the other tokens by concatenating the readout to all other tokens before projecting the representation to the original feature dimension D using a linear layer followed by a GELU non-linearity.
  • reassemble_factors (List[float], optional, defaults to [4, 2, 1, 0.5]) — The up/downsampling factors of the reassemble layers.
  • neck_hidden_sizes (List[int], optional, defaults to [96, 192, 384, 768]) — The hidden sizes to project to for the feature maps of the backbone.
  • fusion_hidden_size (int, optional, defaults to 256) — The number of channels before fusion.
  • head_in_index (int, optional, defaults to -1) — The index of the features to use in the heads.
  • use_batch_norm_in_fusion_residual (bool, optional, defaults to False) — Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
  • use_bias_in_fusion_residual (bool, optional, defaults to True) — Whether to use bias in the pre-activate residual units of the fusion blocks. Note that the signature above shows None rather than True as the default value.
  • add_projection (bool, optional, defaults to False) — Whether to add a projection layer before the depth estimation head.
  • use_auxiliary_head (bool, optional, defaults to True) — Whether to use an auxiliary head during training.
  • auxiliary_loss_weight (float, optional, defaults to 0.4) — Weight of the cross-entropy loss of the auxiliary head.
  • semantic_loss_ignore_index (int, optional, defaults to 255) — The index that is ignored by the loss function of the semantic segmentation model.
  • semantic_classifier_dropout (float, optional, defaults to 0.1) — The dropout ratio for the semantic classification head.
  • backbone_featmap_shape (List[int], optional, defaults to [1, 1024, 24, 24]) — Used only for the hybrid embedding type. The shape of the feature maps of the backbone.
  • neck_ignore_stages (List[int], optional, defaults to [0, 1]) — Used only for the hybrid embedding type. The stages of the readout layers to ignore.
  • backbone_config (Union[Dict[str, Any], PretrainedConfig], optional) — The configuration of the backbone model. Only used in case is_hybrid is True or in case you want to leverage the AutoBackbone API.
  • backbone (str, optional) — Name of backbone to use when backbone_config is None. If use_pretrained_backbone is True, this will load the corresponding pretrained weights from the timm or transformers library. If use_pretrained_backbone is False, this loads the backbone’s config and uses that to initialize the backbone with random weights.
  • use_pretrained_backbone (bool, optional, defaults to False) — Whether to use pretrained weights for the backbone.
  • use_timm_backbone (bool, optional, defaults to False) — Whether to load backbone from the timm library. If False, the backbone is loaded from the transformers library.
  • backbone_kwargs (dict, optional) — Keyword arguments to be passed to AutoBackbone when loading from a checkpoint e.g. {'out_indices': (0, 1, 2, 3)}. Cannot be specified if backbone_config is set.
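
The three readout_type options described above can be illustrated with a small, dependency-free sketch. Tokens are plain Python lists, and the helper names (readout_ignore, readout_add, readout_project) and the toy weights are hypothetical. The sketch assumes the CLS token sits at index 0 and that "project" concatenates each patch token with the CLS token before the linear layer; it is not the library's implementation.

```python
import math


def gelu(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def readout_ignore(tokens):
    # "ignore": drop the CLS token (assumed to be at index 0).
    return [list(t) for t in tokens[1:]]


def readout_add(tokens):
    # "add": add the CLS representation to every patch token.
    cls = tokens[0]
    return [[p + c for p, c in zip(t, cls)] for t in tokens[1:]]


def readout_project(tokens, weight, bias):
    # "project": concatenate [patch, CLS] (length 2*D), project back to D,
    # then apply GELU.
    cls = tokens[0]
    out = []
    for t in tokens[1:]:
        concat = list(t) + list(cls)
        out.append([
            gelu(sum(w * x for w, x in zip(row, concat)) + b)
            for row, b in zip(weight, bias)
        ])
    return out


# Toy input: one CLS token followed by two patch tokens, with D = 2.
tokens = [[1.0, 1.0], [0.5, -0.5], [2.0, 0.0]]
# Hypothetical projection of shape (D, 2*D) that averages patch and CLS values.
weight = [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]]
bias = [0.0, 0.0]

ignored = readout_ignore(tokens)
added = readout_add(tokens)
projected = readout_project(tokens, weight, bias)
```

In all three cases the output has one vector per patch token, so the result can be reshaped into an image-like grid afterwards.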

This is the configuration class to store the configuration of a DPTModel. It is used to instantiate a DPT model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DPT Intel/dpt-large architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
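
To make the relationship between image_size, patch_size and reassemble_factors more concrete, the sketch below computes the patch grid and the spatial size of each reassembled feature map for the default values listed above. It assumes a square input and that each factor simply scales the side length of the patch grid; the function name reassembled_sizes is made up for illustration.

```python
def reassembled_sizes(image_size=384, patch_size=16, factors=(4, 2, 1, 0.5)):
    # Number of patches along one side of the (square) image.
    grid = image_size // patch_size
    # Side length of each feature map after up/downsampling the patch grid.
    sizes = [int(grid * f) for f in factors]
    return grid, sizes


grid, sizes = reassembled_sizes()
print(grid, sizes)  # 24 [96, 48, 24, 12]
```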

Example:

>>> from transformers import DPTModel, DPTConfig

>>> # Initializing a DPT dpt-large style configuration
>>> configuration = DPTConfig()

>>> # Initializing a model from the dpt-large style configuration
>>> model = DPTModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

to_dict

( )

Serializes this instance to a Python dictionary, overriding the default to_dict(). Returns: Dict[str, any], a dictionary of all the attributes that make up this configuration instance.

DPTFeatureExtractor

class transformers.DPTFeatureExtractor

( *args **kwargs )

__call__

( images **kwargs )

Preprocess an image or a batch of images.

post_process_semantic_segmentation

( outputs target_sizes: List = None ) → semantic_segmentation

Parameters

  • outputs (DPTForSemanticSegmentation) — Raw outputs of the model.
  • target_sizes (List[Tuple] of length batch_size, optional) — List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, predictions will not be resized.

Returns

semantic_segmentation

List[torch.Tensor] of length batch_size, where each item is a semantic segmentation map of shape (height, width) corresponding to the target_sizes entry (if target_sizes is specified). Each entry of each torch.Tensor corresponds to a semantic class id.

Converts the output of DPTForSemanticSegmentation into semantic segmentation maps. Only supports PyTorch.

DPTImageProcessor

class transformers.DPTImageProcessor

( do_resize: bool = True size: Dict = None resample: Resampling = <Resampling.BICUBIC: 3> keep_aspect_ratio: bool = False ensure_multiple_of: int = 1 do_rescale: bool = True rescale_factor: Union = 0.00392156862745098 do_normalize: bool = True image_mean: Union = None image_std: Union = None do_pad: bool = False size_divisor: int = None **kwargs )

Parameters

  • do_resize (bool, optional, defaults to True) — Whether to resize the image’s (height, width) dimensions. Can be overridden by do_resize in preprocess.
  • size (Dict[str, int], optional, defaults to {"height": 384, "width": 384}) — Size of the image after resizing. Can be overridden by size in preprocess.
  • resample (PILImageResampling, optional, defaults to Resampling.BICUBIC) — Defines the resampling filter to use if resizing the image. Can be overridden by resample in preprocess.
  • keep_aspect_ratio (bool, optional, defaults to False) — If True, the image is resized to the largest possible size such that the aspect ratio is preserved. Can be overridden by keep_aspect_ratio in preprocess.
  • ensure_multiple_of (int, optional, defaults to 1) — If do_resize is True, the image is resized to a size that is a multiple of this value. Can be overridden by ensure_multiple_of in preprocess.
  • do_rescale (bool, optional, defaults to True) — Whether to rescale the image by the specified scale rescale_factor. Can be overridden by do_rescale in preprocess.
  • rescale_factor (int or float, optional, defaults to 1/255) — Scale factor to use if rescaling the image. Can be overridden by rescale_factor in preprocess.
  • do_normalize (bool, optional, defaults to True) — Whether to normalize the image. Can be overridden by the do_normalize parameter in the preprocess method.
  • image_mean (float or List[float], optional, defaults to IMAGENET_STANDARD_MEAN) — Mean to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the image_mean parameter in the preprocess method.
  • image_std (float or List[float], optional, defaults to IMAGENET_STANDARD_STD) — Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the image_std parameter in the preprocess method.
  • do_pad (bool, optional, defaults to False) — Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in combination with DPT.
  • size_divisor (int, optional) — If do_pad is True, pads the image dimensions to be divisible by this value. This was introduced in the DINOv2 paper, which uses the model in combination with DPT.
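
The interaction between size, keep_aspect_ratio and ensure_multiple_of can be sketched in plain Python. This is only an illustration under stated assumptions: when the aspect ratio is kept, the sketch applies to both axes whichever scale is closest to 1, and it rounds each side to the nearest multiple. The helper names are hypothetical and the actual DPTImageProcessor logic may differ in its details.

```python
def constrain_to_multiple_of(value, multiple):
    # Round to the nearest multiple, but never go below one multiple.
    return max(multiple, round(value / multiple) * multiple)


def resize_output_size(input_hw, target_hw, keep_aspect_ratio=False, multiple=1):
    in_h, in_w = input_hw
    scale_h = target_hw[0] / in_h
    scale_w = target_hw[1] / in_w
    if keep_aspect_ratio:
        # Assumption: use one scale for both axes, the one closest to 1.
        if abs(1 - scale_w) < abs(1 - scale_h):
            scale_h = scale_w
        else:
            scale_w = scale_h
    new_h = constrain_to_multiple_of(scale_h * in_h, multiple)
    new_w = constrain_to_multiple_of(scale_w * in_w, multiple)
    return new_h, new_w


# A 480x640 (height x width) image resized towards 384x384.
print(resize_output_size((480, 640), (384, 384)))  # (384, 384)
print(resize_output_size((480, 640), (384, 384), keep_aspect_ratio=True, multiple=32))  # (384, 512)
```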

Constructs a DPT image processor.
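
As a rough illustration of do_pad and size_divisor, the following sketch computes how many pixels of center padding one axis would need to become divisible by size_divisor. It assumes the padding is split as evenly as possible, with any leftover pixel going to the trailing side; center_pad_amounts is a hypothetical helper, not a library function.

```python
import math


def center_pad_amounts(size, size_divisor):
    # Smallest multiple of size_divisor that is >= size.
    padded = math.ceil(size / size_divisor) * size_divisor
    total = padded - size
    before = total // 2
    after = total - before
    return before, after


print(center_pad_amounts(500, 14))  # (2, 2)
print(center_pad_amounts(375, 14))  # (1, 2)
print(center_pad_amounts(448, 14))  # (0, 0)
```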

preprocess

( images: Union do_resize: bool = None size: int = None keep_aspect_ratio: bool = None ensure_multiple_of: int = None resample: Resampling = None do_rescale: bool = None rescale_factor: float = None do_normalize: bool = None image_mean: Union = None image_std: Union = None do_pad: bool = None size_divisor: int = None return_tensors: Union = None data_format: ChannelDimension = <ChannelDimension.FIRST: 'channels_first'> input_data_format: Union = None )

Parameters

  • images (ImageInput) — Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set do_rescale=False.
  • do_resize (bool, optional, defaults to self.do_resize) — Whether to resize the image.
  • size (Dict[str, int], optional, defaults to self.size) — Size of the image after resizing. If keep_aspect_ratio is True, the image is resized to the largest possible size such that the aspect ratio is preserved. If ensure_multiple_of is set, the image is resized to a size that is a multiple of this value.
  • keep_aspect_ratio (bool, optional, defaults to self.keep_aspect_ratio) — Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If True, the image will be resized to keep the aspect ratio and the size will be the maximum possible.
  • ensure_multiple_of (int, optional, defaults to self.ensure_multiple_of) — Ensure that the image size is a multiple of this value.
  • resample (int, optional, defaults to self.resample) — Resampling filter to use if resizing the image. This can be one of the values of the enum PILImageResampling. Only has an effect if do_resize is set to True.
  • do_rescale (bool, optional, defaults to self.do_rescale) — Whether to rescale the image values to the range [0, 1].
  • rescale_factor (float, optional, defaults to self.rescale_factor) — Rescale factor to rescale the image by if do_rescale is set to True.
  • do_normalize (bool, optional, defaults to self.do_normalize) — Whether to normalize the image.
  • image_mean (float or List[float], optional, defaults to self.image_mean) — Image mean.
  • image_std (float or List[float], optional, defaults to self.image_std) — Image standard deviation.
  • return_tensors (str or TensorType, optional) — The type of tensors to return. Can be one of:
    • Unset: Return a list of np.ndarray.
    • TensorType.TENSORFLOW or 'tf': Return a batch of type tf.Tensor.
    • TensorType.PYTORCH or 'pt': Return a batch of type torch.Tensor.
    • TensorType.NUMPY or 'np': Return a batch of type np.ndarray.
    • TensorType.JAX or 'jax': Return a batch of type jax.numpy.ndarray.
  • data_format (ChannelDimension or str, optional, defaults to ChannelDimension.FIRST) — The channel dimension format for the output image. Can be one of:
    • ChannelDimension.FIRST: image in (num_channels, height, width) format.
    • ChannelDimension.LAST: image in (height, width, num_channels) format.
  • input_data_format (ChannelDimension or str, optional) — The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of:
    • "channels_first" or ChannelDimension.FIRST: image in (num_channels, height, width) format.
    • "channels_last" or ChannelDimension.LAST: image in (height, width, num_channels) format.
    • "none" or ChannelDimension.NONE: image in (height, width) format.

Preprocess an image or batch of images.
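
The do_rescale and do_normalize steps amount to simple per-channel arithmetic, sketched below for a single RGB pixel. The mean and standard deviation of 0.5 per channel are assumed here to be the values behind IMAGENET_STANDARD_MEAN and IMAGENET_STANDARD_STD; check the constants in your installed version before relying on them. The function name is hypothetical.

```python
RESCALE_FACTOR = 1 / 255  # the default shown in the signature, 0.00392156862745098
MEAN = [0.5, 0.5, 0.5]    # assumed value of IMAGENET_STANDARD_MEAN
STD = [0.5, 0.5, 0.5]     # assumed value of IMAGENET_STANDARD_STD


def rescale_and_normalize(pixel, mean=MEAN, std=STD, rescale_factor=RESCALE_FACTOR):
    # Rescale 0..255 values to 0..1, then normalize each channel.
    return [(v * rescale_factor - m) / s for v, m, s in zip(pixel, mean, std)]


black = rescale_and_normalize([0, 0, 0])        # every channel maps to -1.0
white = rescale_and_normalize([255, 255, 255])  # every channel maps to about 1.0
```

With these assumed statistics, the full 0 to 255 input range is mapped to roughly the interval from -1 to 1.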

post_process_semantic_segmentation

( outputs target_sizes: List = None ) → semantic_segmentation

Parameters

  • outputs (DPTForSemanticSegmentation) — Raw outputs of the model.
  • target_sizes (List[Tuple] of length batch_size, optional) — List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, predictions will not be resized.

Returns

semantic_segmentation

List[torch.Tensor] of length batch_size, where each item is a semantic segmentation map of shape (height, width) corresponding to the target_sizes entry (if target_sizes is specified). Each entry of each torch.Tensor corresponds to a semantic class id.

Converts the output of DPTForSemanticSegmentation into semantic segmentation maps. Only supports PyTorch.
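
Conceptually, turning logits into a segmentation map means picking the highest-scoring class for every pixel. The sketch below shows that step on nested Python lists for a single image; it assumes a per-pixel argmax over the class dimension and leaves out the optional resizing to target_sizes. The function name argmax_segmentation is hypothetical.

```python
def argmax_segmentation(logits):
    # logits: nested lists of shape (num_labels, height, width) for one image.
    num_labels = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    return [
        [max(range(num_labels), key=lambda c: logits[c][y][x]) for x in range(width)]
        for y in range(height)
    ]


logits = [
    [[2.0, 0.1], [0.3, 0.2]],  # scores for class 0
    [[1.0, 0.9], [0.1, 0.8]],  # scores for class 1
    [[0.5, 0.4], [0.7, 0.3]],  # scores for class 2
]
seg_map = argmax_segmentation(logits)
print(seg_map)  # [[0, 1], [2, 1]]
```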

DPTModel

class transformers.DPTModel

( config add_pooling_layer = True )

Parameters

  • config (DPTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The bare DPT Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: FloatTensor head_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.models.dpt.modeling_dpt.BaseModelOutputWithPoolingAndIntermediateActivations or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See DPTImageProcessor.__call__() for details.
  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.

Returns

transformers.models.dpt.modeling_dpt.BaseModelOutputWithPoolingAndIntermediateActivations or tuple(torch.FloatTensor)

A transformers.models.dpt.modeling_dpt.BaseModelOutputWithPoolingAndIntermediateActivations or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (DPTConfig) and inputs.

  • last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)) — Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns the classification token after processing through a linear layer and a tanh activation function. The linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • intermediate_activations (tuple(torch.FloatTensor), optional) — Intermediate activations that can be used to compute hidden states of the model at various layers.

The DPTModel forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, DPTModel
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
>>> model = DPTModel.from_pretrained("Intel/dpt-large")

>>> inputs = image_processor(image, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 577, 1024]
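
The sequence length of 577 in the shape above can be reproduced with a little arithmetic, assuming the checkpoint processes 384x384 inputs with 16x16 patches (the image_size and patch_size defaults of DPTConfig) and prepends a single CLS token. Under that reading, the last dimension, 1024, would be the hidden size of this particular checkpoint. The helper below is a hypothetical sketch, not a library function.

```python
def vit_sequence_length(image_size, patch_size, num_prefix_tokens=1):
    # Patches along one side of a square image.
    patches_per_side = image_size // patch_size
    # All patch tokens plus the prefix (CLS) token(s).
    return patches_per_side ** 2 + num_prefix_tokens


seq_len = vit_sequence_length(384, 16)
print(seq_len)  # 577, i.e. 24 * 24 + 1
```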

DPTForDepthEstimation

class transformers.DPTForDepthEstimation

( config )

Parameters

  • config (DPTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: FloatTensor head_mask: Optional = None labels: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.DepthEstimatorOutput or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See DPTImageProcessor.__call__() for details.
  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • labels (torch.LongTensor of shape (batch_size, height, width), optional) — Ground truth depth estimation maps for computing the loss.

Returns

transformers.modeling_outputs.DepthEstimatorOutput or tuple(torch.FloatTensor)

A transformers.modeling_outputs.DepthEstimatorOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (DPTConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.

  • predicted_depth (torch.FloatTensor of shape (batch_size, height, width)) — Predicted depth for each pixel.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, num_channels, height, width).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, patch_size, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

The DPTForDepthEstimation forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Examples:

>>> from transformers import AutoImageProcessor, DPTForDepthEstimation
>>> import torch
>>> import numpy as np
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
>>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

>>> # prepare image for the model
>>> inputs = image_processor(images=image, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # interpolate to original size
>>> post_processed_output = image_processor.post_process_depth_estimation(
...     outputs,
...     target_sizes=[(image.height, image.width)],
... )

>>> # visualize the prediction
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = predicted_depth * 255 / predicted_depth.max()
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint8"))
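
The scaling used in the visualization step above can be checked on a tiny example. The sketch below mirrors it with nested Python lists instead of tensors: values are scaled so that the maximum depth maps to 255 and are then truncated to integers, as a cast to uint8 would do. It assumes non-negative depths with a positive maximum; depth_to_uint8 is a hypothetical helper.

```python
def depth_to_uint8(depth_rows):
    # Scale so the largest depth becomes 255, then truncate like a uint8 cast.
    max_depth = max(max(row) for row in depth_rows)
    return [[int(v * 255 / max_depth) for v in row] for row in depth_rows]


gray = depth_to_uint8([[0.5, 1.0], [2.0, 4.0]])
print(gray)  # [[31, 63], [127, 255]]
```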

DPTForSemanticSegmentation

class transformers.DPTForSemanticSegmentation

( config )

Parameters

  • config (DPTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

DPT Model with a semantic segmentation head on top e.g. for ADE20k, CityScapes.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: Optional = None head_mask: Optional = None labels: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.SemanticSegmenterOutput or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See DPTImageProcessor.__call__() for details.
  • head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:

    • 1 indicates the head is not masked,
    • 0 indicates the head is masked.
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • labels (torch.LongTensor of shape (batch_size, height, width), optional) — Ground truth semantic segmentation maps for computing the loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels > 1, a classification loss is computed (Cross-Entropy).
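
To illustrate how labels might enter the loss, here is a minimal pure-Python sketch of a per-pixel cross-entropy that skips pixels labeled with semantic_loss_ignore_index (255 by default) and adds an auxiliary-head term weighted by auxiliary_loss_weight (0.4 by default). The combination main + weight * auxiliary is an assumption based on how the two config values are described, and both function names are hypothetical.

```python
import math


def cross_entropy_ignore(logits, labels, ignore_index=255):
    # logits: one list of class scores per pixel; labels: one class id per pixel.
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # ignored pixels do not contribute to the loss
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[label]
        count += 1
    # This sketch returns 0.0 when every pixel is ignored.
    return total / count if count else 0.0


def combined_loss(main_logits, aux_logits, labels, aux_weight=0.4):
    main = cross_entropy_ignore(main_logits, labels)
    aux = cross_entropy_ignore(aux_logits, labels)
    return main + aux_weight * aux


# Two pixels, two classes; the second pixel is ignored via label 255.
pixel_logits = [[0.0, 0.0], [5.0, 0.0]]
pixel_labels = [0, 255]
loss = combined_loss(pixel_logits, pixel_logits, pixel_labels)
```

In this toy case only the first pixel counts, so the loss equals 1.4 times log(2).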

Returns

transformers.modeling_outputs.SemanticSegmenterOutput or tuple(torch.FloatTensor)

A transformers.modeling_outputs.SemanticSegmenterOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (DPTConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.

  • logits (torch.FloatTensor of shape (batch_size, config.num_labels, logits_height, logits_width)) — Classification scores for each pixel.

    The logits returned do not necessarily have the same size as the pixel_values passed as inputs. This is to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the original image size as post-processing. You should always check your logits shape and resize as needed.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, patch_size, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, patch_size, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

The DPTForSemanticSegmentation forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Examples:

>>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
>>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

>>> inputs = image_processor(images=image, return_tensors="pt")

>>> outputs = model(**inputs)
>>> logits = outputs.logits