PanoVLM_dev0 / mask2former.py
zhouyik's picture
Upload folder using huggingface_hub
4ee9c8f verified
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Optional, Tuple, Union
from transformers.models.mask2former.modeling_mask2former import (
Mask2FormerMaskedAttentionDecoderOutput, Mask2FormerModelOutput,
Mask2FormerForUniversalSegmentationOutput, Mask2FormerMLPPredictionHead,
sample_point, pair_wise_sigmoid_cross_entropy_loss, pair_wise_dice_loss,
sigmoid_cross_entropy_loss, dice_loss)
from torch import Tensor
import torch.nn.functional as F
from transformers.file_utils import is_scipy_available
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
def get_classification_logits(x, text_classifier, logit_scale):
    """Score query embeddings against text embeddings by scaled cosine similarity.

    Args:
        x: Query embeddings of shape ``[B, *, C]``.
        text_classifier: Text (class) embeddings of shape ``[num_classes, C]``.
        logit_scale: Learnable scalar stored in log space, see
            https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/model.py#L201

    Returns:
        Logits of shape ``[B, *, num_classes]``.
    """
    # The scale is exponentiated before use and capped at 100.
    temperature = logit_scale.exp().clamp(max=100)
    unit_queries = F.normalize(x, dim=-1)
    unit_classes = F.normalize(text_classifier, dim=-1)
    return (temperature * unit_queries) @ unit_classes.T
def _post_init(self):
    """Attach the open-vocabulary classification head to ``self``.

    Sets ``self.class_embed`` to a ``Mask2FormerMLPPredictionHead`` built with
    ``config.hidden_dim`` for its three size arguments and ``3`` as the last argument
    (presumably the layer count; confirm against the head's signature), and
    ``self.logit_scale`` to a learnable scalar initialised to ``log(1 / 0.07)``.
    """
    width = self.config.hidden_dim
    self.class_embed = Mask2FormerMLPPredictionHead(width, width, width, 3)

    initial_log_scale = np.log(1 / 0.07)
    self.logit_scale = nn.Parameter(torch.ones([]) * initial_log_scale)
def ov_class_predictor(self, x, text_classifier):
    """Compute open-vocabulary class logits separately for every batch element.

    Args:
        x: Batched query states; they are first passed through ``self.class_embed``.
        text_classifier: Iterable with one text-embedding tensor per batch element.
            Each sample is scored against its own tensor.

    Returns:
        A list with one logits tensor per batch element (as many as the shorter of
        ``x`` and ``text_classifier``).
    """
    embedded = self.class_embed(x)
    return [
        # Add a leading batch axis for the scorer, then strip it again.
        get_classification_logits(
            sample_queries.unsqueeze(0), sample_classes, self.logit_scale
        ).squeeze(0)
        for sample_queries, sample_classes in zip(embedded, text_classifier)
    ]
def Mask2FormerLoss_loss_labels(
    self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]:
    """Weighted cross-entropy between per-sample query logits and matched class labels.

    Each sample is handled on its own, so the number of classes (last logits axis) may
    differ between samples. The last class index of every sample is treated as the
    "no object" class: it is the target of every unmatched query and is weighted by
    ``self.eos_coef``.

    Args:
        class_queries_logits: Per-sample logits, indexable by sample; each entry has
            queries on its first axis and classes on its last axis.
        class_labels: Per-sample ground-truth class indices.
        indices: Per-sample ``(query_indices, label_indices)`` pairs from the matcher.

    Returns:
        ``{"loss_cross_entropy": scalar}``, the plain mean of all per-query weighted
        losses across the batch.
    """
    # The query count is read from the first sample and reused for the whole batch.
    query_count = class_queries_logits[0].shape[0]

    per_sample_losses = []
    for sample_idx, logits in enumerate(class_queries_logits):
        device = logits.device
        no_object_class = logits.shape[-1] - 1

        class_weights = torch.ones(no_object_class + 1)
        class_weights[-1] = self.eos_coef
        class_weights = class_weights.to(device).to(logits.dtype)

        matched_queries = indices[sample_idx][0]
        matched_labels = class_labels[sample_idx][indices[sample_idx][1]]

        # Every query defaults to "no object"; matched queries get their label.
        targets = torch.full(
            (query_count,), fill_value=no_object_class, dtype=torch.int64, device=device
        )
        targets[matched_queries] = matched_labels.to(device)

        # cross_entropy expects (batch, classes, queries) logits and (batch, queries) targets.
        per_sample_losses.append(
            F.cross_entropy(
                logits.unsqueeze(0).transpose(1, 2),
                targets.unsqueeze(0),
                weight=class_weights,
                reduction="none",
            )
        )

    return {"loss_cross_entropy": torch.cat(per_sample_losses, dim=-1).mean()}
def Mask2FormerLoss_loss_masks(
    self,
    masks_queries_logits: torch.Tensor,
    mask_labels: List[torch.Tensor],
    indices: Tuple[np.array],
    num_masks: int,
) -> Dict[str, torch.Tensor]:
    """Point-sampled binary cross-entropy and dice losses for the matched masks.

    Args:
        masks_queries_logits: Predicted mask logits, indexed here with the matched
            (batch, query) indices (assumed `(batch_size, num_queries, height, width)`;
            confirm against the caller).
        mask_labels: Per-sample ground-truth masks; padded to a common size with
            ``self._pad_images_to_max_in_batch`` before indexing.
        indices: Per-sample matcher output, turned into prediction and target index
            tuples by the two ``_get_*_permutation_indices`` helpers.
        num_masks: Normaliser forwarded to both loss functions.

    Returns:
        ``{"loss_mask": ..., "loss_dice": ...}``.
    """
    src_idx = self._get_predictions_permutation_indices(indices)
    tgt_idx = self._get_targets_permutation_indices(indices)

    # Matched predictions, with a singleton channel axis for point sampling.
    pred_masks = masks_queries_logits[src_idx][:, None]

    # Pad all targets to the largest mask in the batch, then pick the matched ones.
    target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
    target_masks = target_masks[tgt_idx][:, None]

    # No need to upsample predictions as we are using normalized coordinates.
    # Targets and coordinates follow the prediction dtype instead of a hard-coded
    # bfloat16, so labels are sampled in the same precision as the logits they are
    # compared with (identical to the previous behaviour when the model runs in bf16).
    sample_dtype = pred_masks.dtype
    with torch.no_grad():
        point_coordinates = self.sample_points_using_uncertainty(
            pred_masks,
            self.calculate_uncertainty,
            self.num_points,
            self.oversample_ratio,
            self.importance_sample_ratio,
        ).to(sample_dtype)
        point_labels = sample_point(
            target_masks.to(sample_dtype), point_coordinates, align_corners=False
        ).squeeze(1)

    # Sampled outside no_grad so gradients flow back to the predicted masks.
    point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)

    return {
        "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
        "loss_dice": dice_loss(point_logits, point_labels, num_masks),
    }
def Mask2FormerLoss_sample_points_using_uncertainty(
    self,
    logits: torch.Tensor,
    uncertainty_function,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
) -> torch.Tensor:
    """Pick ``num_points`` sampling coordinates per mask, favouring uncertain ones.

    ``int(num_points * oversample_ratio)`` random candidates are drawn per mask. The
    ``int(importance_sample_ratio * num_points)`` candidates with the highest
    uncertainty are kept, and the remainder is filled with fresh uniform random points.

    Args:
        logits: Mask logits; the first axis enumerates masks.
        uncertainty_function: Callable mapping sampled logits to uncertainties; channel
            0 of its output is ranked along the last axis.
        num_points: Number of coordinates to return per mask.
        oversample_ratio: Candidate oversampling factor.
        importance_sample_ratio: Fraction of points chosen by uncertainty.

    Returns:
        Tensor of shape ``(num_masks, num_points, 2)`` with coordinates in ``[0, 1)``.
    """
    mask_count = logits.shape[0]
    candidate_count = int(num_points * oversample_ratio)
    uncertain_count = int(importance_sample_ratio * num_points)
    random_count = num_points - uncertain_count

    # Random candidates, evaluated on the predictions (cast to the logits' dtype for sampling).
    candidates = torch.rand(mask_count, candidate_count, 2, device=logits.device)
    candidate_logits = sample_point(logits, candidates.to(logits.dtype), align_corners=False)
    uncertainties = uncertainty_function(candidate_logits)

    # Keep the most uncertain candidates of every mask.
    top_idx = torch.topk(uncertainties[:, 0, :], k=uncertain_count, dim=1)[1]
    selected = torch.gather(candidates, 1, top_idx.unsqueeze(-1).expand(-1, -1, 2))

    if random_count <= 0:
        return selected

    filler = torch.rand(mask_count, random_count, 2, device=logits.device)
    return torch.cat([selected, filler], dim=1)
@torch.no_grad()
def Mask2FormerHungarianMatcher_forward(
    self,
    masks_queries_logits: torch.Tensor,
    class_queries_logits: torch.Tensor,
    mask_labels: torch.Tensor,
    class_labels: torch.Tensor,
) -> List[Tuple[Tensor]]:
    """Match predicted queries to ground-truth masks with the Hungarian algorithm.

    For every sample a cost matrix is built from a class term, a point-sampled sigmoid
    cross-entropy term and a point-sampled dice term, weighted by ``self.cost_class``,
    ``self.cost_mask`` and ``self.cost_dice``.

    Args:
        masks_queries_logits: Batched predicted mask logits; first axis is the batch.
        class_queries_logits: Per-sample class logits, indexable by sample.
        mask_labels: Per-sample ground-truth masks.
        class_labels: Per-sample ground-truth class indices.

    Returns:
        One ``(query_indices, label_indices)`` pair of int64 tensors per sample.
    """
    matches = []
    for sample_idx in range(masks_queries_logits.shape[0]):
        class_probs = class_queries_logits[sample_idx].softmax(-1)
        query_masks = masks_queries_logits[sample_idx]

        # Classification cost. Contrary to the loss, the NLL is approximated by
        # 1 - proba[target class]; the constant 1 does not change the matching and
        # is omitted.
        cost_class = -class_probs[:, class_labels[sample_idx]]

        # Bring targets to the predictions' dtype/device and add a channel axis to both.
        gt_masks = mask_labels[sample_idx].to(query_masks)[:, None]
        query_masks = query_masks[:, None]

        # The same random points are used for predictions and ground truth.
        shared_points = torch.rand(1, self.num_points, 2, device=query_masks.device)
        gt_points = shared_points.repeat(gt_masks.shape[0], 1, 1).to(gt_masks.dtype)
        gt_samples = sample_point(gt_masks, gt_points, align_corners=False).squeeze(1)
        query_points = shared_points.repeat(query_masks.shape[0], 1, 1).to(query_masks.dtype)
        query_samples = sample_point(query_masks, query_points, align_corners=False).squeeze(1)

        # Pairwise costs, each of shape (num_queries, num_labels).
        cost_mask = pair_wise_sigmoid_cross_entropy_loss(query_samples, gt_samples)
        cost_dice = pair_wise_dice_loss(query_samples, gt_samples)

        cost = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice

        # Bound the costs and zero out NaNs to avoid
        # ``ValueError: cost matrix is infeasible`` from scipy.
        cost = torch.minimum(cost, torch.tensor(1e10))
        cost = torch.maximum(cost, torch.tensor(-1e10))
        cost = torch.nan_to_num(cost, 0)

        # Solve the assignment with the Hungarian algorithm from scipy (float32, CPU).
        query_idx, label_idx = linear_sum_assignment(cost.to(torch.float32).cpu())
        matches.append(
            (
                torch.as_tensor(query_idx, dtype=torch.int64),
                torch.as_tensor(label_idx, dtype=torch.int64),
            )
        )
    return matches
def Mask2FormerMaskedAttentionDecoder_forward_first3layers(
    self,
    inputs_embeds: torch.Tensor = None,
    multi_stage_positional_embeddings: torch.Tensor = None,
    pixel_embeddings: torch.Tensor = None,
    encoder_hidden_states: torch.Tensor = None,
    query_position_embeddings: torch.Tensor = None,
    feature_size_list: List = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Run only the first three decoder layers (``self.layers[:3]``).

    Args:
        inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            The query embeddings that are passed into the decoder.
        multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
            Position embeddings that are added to the keys in each cross(masked)-attention layer.
        pixel_embeddings (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
            Decoder.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross(masked)-attention of the decoder.
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
            Position embeddings that are added to the queries and keys in each self-attention layer.
        feature_size_list (`List[torch.Size]`):
            This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a `Mask2FormerMaskedAttentionDecoderOutput` instead of a plain tuple.

    Returns:
        A `Mask2FormerMaskedAttentionDecoderOutput`, or with ``return_dict=False`` a tuple of its non-``None``
        fields. The last hidden state has its first two axes swapped relative to ``inputs_embeds``.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    if inputs_embeds is not None:
        hidden_states = inputs_embeds

    # Optional per-layer collections (stay None when not requested).
    collected_hidden = [] if output_hidden_states else None
    collected_attentions = [] if output_attentions else None

    # Layer-normed states (used for class logits) and mask predictions, seeded with
    # the prediction made directly from the input queries.
    normed_states = self.layernorm(inputs_embeds)
    predicted_mask, attention_mask = self.mask_predictor(normed_states, pixel_embeddings, feature_size_list[0])
    collected_normed = [normed_states]
    collected_masks = [predicted_mask]

    for idx, decoder_layer in enumerate(self.layers[:3]):
        if output_hidden_states:
            collected_hidden.append(hidden_states)

        # LayerDrop: the random draw happens on every iteration, layers are only
        # skipped in training mode.
        drop_draw = torch.rand([])
        if self.training and (drop_draw < self.layerdrop):
            continue

        if self.gradient_checkpointing and self.training:
            # NOTE(review): this branch passes its arguments positionally and does not
            # refresh the attention mask or the intermediate outputs; verify against
            # the decoder layer signature before relying on gradient checkpointing.
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                None,
                None,
                output_attentions,
            )
        else:
            # Rows whose mask sums to the full row length are multiplied by zero.
            # Multiply the attention mask instead of indexing to avoid issue in torch.export.
            keep_row = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
            attention_mask = attention_mask * keep_row.unsqueeze(-1)

            layer_outputs = decoder_layer(
                hidden_states,
                level_index=idx % self.num_feature_levels,
                position_embeddings=multi_stage_positional_embeddings,
                query_position_embeddings=query_position_embeddings,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

            # Predict the masks for this layer and the attention mask for the next level.
            normed_states = self.layernorm(layer_outputs[0])
            next_level = (idx + 1) % self.num_feature_levels
            predicted_mask, attention_mask = self.mask_predictor(
                normed_states, pixel_embeddings, feature_size_list[next_level]
            )
            collected_masks.append(predicted_mask)
            collected_normed.append(normed_states)

        hidden_states = layer_outputs[0]
        if output_attentions:
            collected_attentions.append(layer_outputs[1])

    # Add hidden states from the last executed decoder layer.
    if output_hidden_states:
        collected_hidden.append(hidden_states)

    hidden_states = hidden_states.transpose(1, 0)
    all_hidden_states = None if collected_hidden is None else tuple(collected_hidden)
    attentions = None if collected_attentions is None else tuple(collected_attentions)
    intermediate = tuple(collected_normed)
    intermediate_mask_predictions = tuple(collected_masks)

    if return_dict:
        return Mask2FormerMaskedAttentionDecoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=attentions,
            intermediate_hidden_states=intermediate,
            masks_queries_logits=intermediate_mask_predictions,
        )

    outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions]
    return tuple(value for value in outputs if value is not None)
def Mask2FormerMaskedAttentionDecoder_forward_last3layers(
    self,
    inputs_embeds: torch.Tensor = None,
    multi_stage_positional_embeddings: torch.Tensor = None,
    pixel_embeddings: torch.Tensor = None,
    encoder_hidden_states: torch.Tensor = None,
    query_position_embeddings: torch.Tensor = None,
    feature_size_list: List = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Run the decoder layers from index 3 onwards (``self.layers[3:]``), keeping each layer's original index for
    feature-level selection.

    Args:
        inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            The query embeddings that are passed into the decoder.
        multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
            Position embeddings that are added to the keys in each cross(masked)-attention layer.
        pixel_embeddings (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
            Decoder.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross(masked)-attention of the decoder.
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
            Position embeddings that are added to the queries and keys in each self-attention layer.
        feature_size_list (`List[torch.Size]`):
            This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a `Mask2FormerMaskedAttentionDecoderOutput` instead of a plain tuple.

    Returns:
        A `Mask2FormerMaskedAttentionDecoderOutput`, or with ``return_dict=False`` a tuple of its non-``None``
        fields. The last hidden state has its first two axes swapped relative to ``inputs_embeds``.
    """
    first_layer = 3

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is not None:
        hidden_states = inputs_embeds

    # intermediate hidden states with layernorm applied - required for predicting class logits
    intermediate = ()
    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    attentions = () if output_attentions else None
    # intermediate mask predictions from transformer decoder layers
    intermediate_mask_predictions = ()

    intermediate_hidden_states = self.layernorm(inputs_embeds)
    intermediate += (intermediate_hidden_states,)

    # The initial attention mask must match the feature level that the first executed
    # layer attends to. That layer has index 3, so its level is 3 % num_feature_levels,
    # which equals the previously hard-coded 0 only when 3 % num_feature_levels == 0.
    predicted_mask, attention_mask = self.mask_predictor(
        intermediate_hidden_states,
        pixel_embeddings,
        feature_size_list[first_layer % self.num_feature_levels],
    )
    intermediate_mask_predictions += (predicted_mask,)

    for idx, decoder_layer in enumerate(self.layers[first_layer:], start=first_layer):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # LayerDrop: the random draw happens on every iteration, layers are only
        # skipped in training mode.
        dropout_probability = torch.rand([])
        if self.training and (dropout_probability < self.layerdrop):
            continue

        if self.gradient_checkpointing and self.training:
            # NOTE(review): this branch passes its arguments positionally and does not
            # refresh the attention mask or the intermediate outputs; verify against
            # the decoder layer signature before relying on gradient checkpointing.
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                None,
                None,
                output_attentions,
            )
        else:
            level_index = idx % self.num_feature_levels

            where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
            # Multiply the attention mask instead of indexing to avoid issue in torch.export.
            attention_mask = attention_mask * where.unsqueeze(-1)

            layer_outputs = decoder_layer(
                hidden_states,
                level_index=level_index,
                position_embeddings=multi_stage_positional_embeddings,
                query_position_embeddings=query_position_embeddings,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

            intermediate_hidden_states = self.layernorm(layer_outputs[0])

            # The mask predicted here is sized for the level used by the next layer.
            predicted_mask, attention_mask = self.mask_predictor(
                intermediate_hidden_states,
                pixel_embeddings,
                feature_size_list[(idx + 1) % self.num_feature_levels],
            )
            intermediate_mask_predictions += (predicted_mask,)

            # add intermediate hidden states with layer norm applied which will be used for predicting class logits
            intermediate += (intermediate_hidden_states,)

        hidden_states = layer_outputs[0]

        if output_attentions:
            attentions += (layer_outputs[1],)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    hidden_states = hidden_states.transpose(1, 0)
    if not return_dict:
        outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions]
        return tuple(v for v in outputs if v is not None)

    return Mask2FormerMaskedAttentionDecoderOutput(
        last_hidden_state=hidden_states,
        hidden_states=all_hidden_states,
        attentions=attentions,
        intermediate_hidden_states=intermediate,
        masks_queries_logits=intermediate_mask_predictions,
    )
def Mask2FormerTransformerModule_forward_first_part(
    self,
    multi_scale_features: List[Tensor],
    mask_features: Tensor,
    output_hidden_states: bool = False,
    output_attentions: bool = False,
) -> Mask2FormerMaskedAttentionDecoderOutput:
    """Prepare per-level inputs and run the first three decoder layers on the learned queries.

    Args:
        multi_scale_features: One feature map per level; the first
            ``self.num_feature_levels`` entries are used.
        mask_features: Pixel embeddings handed to the decoder for mask prediction.
        output_hidden_states: Forwarded to the decoder.
        output_attentions: Forwarded to the decoder.

    Returns:
        The result of ``self.decoder.Mask2FormerMaskedAttentionDecoder_forward_first3layers``,
        called with ``return_dict=True``.
    """
    level_sizes = []
    level_positions = []
    level_features = []
    for level in range(self.num_feature_levels):
        feature_map = multi_scale_features[level]
        level_sizes.append(feature_map.shape[-2:])

        positions = self.position_embedder(feature_map, None).flatten(2)
        projected = self.input_projections[level](feature_map).flatten(2)
        projected = projected + self.level_embed.weight[level][None, :, None]

        # Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels)
        level_positions.append(positions.permute(2, 0, 1))
        level_features.append(projected.permute(2, 0, 1))

    batch_size = level_features[0].shape[1]

    # Learned queries broadcast over the batch: [num_queries, batch_size, num_channels]
    query_position = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
    query_content = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1)

    return self.decoder.Mask2FormerMaskedAttentionDecoder_forward_first3layers(
        inputs_embeds=query_content,
        multi_stage_positional_embeddings=level_positions,
        pixel_embeddings=mask_features,
        encoder_hidden_states=level_features,
        query_position_embeddings=query_position,
        feature_size_list=level_sizes,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=True,
    )
def Mask2FormerTransformerModule_forward_second_part(
    self,
    query_features: Tensor,
    query_embeddings: Tensor,
    multi_scale_features: List[Tensor],
    mask_features: Tensor,
    output_hidden_states: bool = False,
    output_attentions: bool = False,
) -> Mask2FormerMaskedAttentionDecoderOutput:
    """Prepare per-level inputs and run the remaining decoder layers on caller-supplied queries.

    Unlike the first part, the query content and positional embeddings are not taken
    from the module's learned embeddings; they are passed in by the caller.

    Args:
        query_features: Query content fed to the decoder as ``inputs_embeds``.
        query_embeddings: Query positional embeddings fed to the decoder.
        multi_scale_features: One feature map per level; the first
            ``self.num_feature_levels`` entries are used.
        mask_features: Pixel embeddings handed to the decoder for mask prediction.
        output_hidden_states: Forwarded to the decoder.
        output_attentions: Forwarded to the decoder.

    Returns:
        The result of ``self.decoder.Mask2FormerMaskedAttentionDecoder_forward_last3layers``,
        called with ``return_dict=True``.
    """
    multi_stage_features = []
    multi_stage_positional_embeddings = []
    size_list = []

    for i in range(self.num_feature_levels):
        features = multi_scale_features[i]
        size_list.append(features.shape[-2:])

        positions = self.position_embedder(features, None).flatten(2)
        projected = self.input_projections[i](features).flatten(2) + self.level_embed.weight[i][None, :, None]

        # Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels)
        multi_stage_positional_embeddings.append(positions.permute(2, 0, 1))
        multi_stage_features.append(projected.permute(2, 0, 1))

    return self.decoder.Mask2FormerMaskedAttentionDecoder_forward_last3layers(
        inputs_embeds=query_features,
        multi_stage_positional_embeddings=multi_stage_positional_embeddings,
        pixel_embeddings=mask_features,
        encoder_hidden_states=multi_stage_features,
        query_position_embeddings=query_embeddings,
        feature_size_list=size_list,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=True,
    )
def Mask2FormerModel_forward_first_part(
    self,
    pixel_values: Tensor,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Tuple[Tensor, object]:
    r"""
    Run the pixel-level module and the first half of the transformer decoder.

    Args:
        pixel_values: Input images, forwarded to ``self.pixel_level_module``.
        pixel_mask: Accepted for signature compatibility; it is not used by this function.
        output_hidden_states: Forwarded to the pixel-level module; defaults to
            ``config.output_hidden_states``. The transformer module is always called with
            ``output_hidden_states=True``.
        output_attentions: Forwarded to the transformer module; defaults to
            ``config.output_attentions``.
        return_dict: Accepted for signature compatibility; it is not used by this function.

    Returns:
        ``(query_features, pixel_level_module_output)``: the ``last_hidden_state`` of the
        first-part transformer output, and the unmodified pixel-level module output that
        the second part needs.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    # No default pixel mask is built here: nothing in this function consumes it, and
    # allocating a (batch, height, width) tensor on every call was wasted work.
    pixel_level_module_output = self.pixel_level_module(
        pixel_values=pixel_values, output_hidden_states=output_hidden_states
    )

    transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_first_part(
        multi_scale_features=pixel_level_module_output.decoder_hidden_states,
        mask_features=pixel_level_module_output.decoder_last_hidden_state,
        output_hidden_states=True,
        output_attentions=output_attentions,
    )

    query_features = transformer_module_output.last_hidden_state
    return query_features, pixel_level_module_output
def Mask2FormerModel_forward_second_part(
    self,
    query_features: Tensor,
    query_embeddings: Tensor,
    pixel_level_module_output,
    pixel_values: Tensor,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Mask2FormerModelOutput:
    r"""
    Run the second half of the transformer decoder and assemble the model output.

    Args:
        query_features: Query content for the remaining decoder layers.
        query_embeddings: Query positional embeddings for the remaining decoder layers.
        pixel_level_module_output: The pixel-level module output returned by the first part; its
            ``decoder_hidden_states`` and ``decoder_last_hidden_state`` feed the transformer module.
        pixel_values: Accepted for signature compatibility; it is not used by this function.
        pixel_mask: Accepted for signature compatibility; it is not used by this function.
        output_hidden_states: Whether encoder / pixel-decoder / transformer-decoder hidden states and the
            intermediate decoder states are copied into the output; defaults to
            ``config.output_hidden_states``. The transformer module itself is always called with
            ``output_hidden_states=True``.
        output_attentions: Forwarded to the transformer module; defaults to ``config.output_attentions``.
        return_dict: Whether to return a `Mask2FormerModelOutput` instead of a plain tuple; defaults to
            ``config.use_return_dict``.

    Returns:
        `Mask2FormerModelOutput`, or with ``return_dict=False`` a tuple of its non-``None`` values.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # No default pixel mask is built here: nothing in this function consumes it, and
    # allocating a (batch, height, width) tensor on every call was wasted work.
    transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_second_part(
        query_features=query_features,
        query_embeddings=query_embeddings,
        multi_scale_features=pixel_level_module_output.decoder_hidden_states,
        mask_features=pixel_level_module_output.decoder_last_hidden_state,
        output_hidden_states=True,
        output_attentions=output_attentions,
    )

    encoder_hidden_states = None
    pixel_decoder_hidden_states = None
    transformer_decoder_hidden_states = None
    transformer_decoder_intermediate_states = None

    # These fields are only exposed on request and stay None otherwise.
    if output_hidden_states:
        encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
        pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
        transformer_decoder_hidden_states = transformer_module_output.hidden_states
        transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states

    output = Mask2FormerModelOutput(
        encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state,
        pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state,
        transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state,
        encoder_hidden_states=encoder_hidden_states,
        pixel_decoder_hidden_states=pixel_decoder_hidden_states,
        transformer_decoder_hidden_states=transformer_decoder_hidden_states,
        transformer_decoder_intermediate_states=transformer_decoder_intermediate_states,
        attentions=transformer_module_output.attentions,
        masks_queries_logits=transformer_module_output.masks_queries_logits,
    )

    if not return_dict:
        output = tuple(v for v in output.values() if v is not None)

    return output
def Mask2FormerForUniversalSegmentation_forward_first_part(
    self,
    pixel_values: Tensor,
    mask_labels: Optional[List[Tensor]] = None,
    class_labels: Optional[List[Tensor]] = None,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_auxiliary_logits: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Tuple[Tensor, object]:
    """Run the first half of the segmentation model and return its intermediate results.

    Args:
        pixel_values: Input images, forwarded to the base model.
        mask_labels: Accepted for signature compatibility; not used by this function.
        class_labels: Accepted for signature compatibility; not used by this function.
        pixel_mask: Forwarded to the base model.
        output_hidden_states: Defaults to ``config.output_hidden_states``; the base model
            receives ``output_hidden_states or config.use_auxiliary_loss``.
        output_auxiliary_logits: Accepted for signature compatibility; not used by this function.
        output_attentions: Forwarded to the base model; defaults to ``config.output_attentions``.
        return_dict: Accepted for signature compatibility; the base model is always called
            with ``return_dict=True``.

    Returns:
        ``(query_features, pixel_level_module_output)`` exactly as returned by
        ``self.model.Mask2FormerModel_forward_first_part``.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    query_features, pixel_level_module_output = self.model.Mask2FormerModel_forward_first_part(
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
        output_attentions=output_attentions,
        return_dict=True,
    )
    return query_features, pixel_level_module_output
def Mask2FormerForUniversalSegmentation_forward_second_part(
    self,
    query_features,
    query_embeddings,
    pixel_level_module_output,
    text_classifier,
    pixel_values: Tensor,
    mask_labels: Optional[List[Tensor]] = None,
    class_labels: Optional[List[Tensor]] = None,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_auxiliary_logits: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Mask2FormerForUniversalSegmentationOutput:
    """Finish the forward pass: run the remaining decoder layers, classify queries, compute the loss.

    Args:
        query_features: Query content for the remaining decoder layers.
        query_embeddings: Query positional embeddings for the remaining decoder layers.
        pixel_level_module_output: Pixel-level module output produced by the first part.
        text_classifier: Per-sample text embeddings used by ``self.ov_class_predictor``.
        pixel_values: Forwarded to the base model.
        mask_labels: Ground-truth masks; the loss is computed only when both label arguments are given.
        class_labels: Ground-truth class indices; see ``mask_labels``.
        pixel_mask: Forwarded to the base model.
        output_hidden_states: Whether hidden states are copied into the output; defaults to
            ``config.output_hidden_states``.
        output_auxiliary_logits: Whether auxiliary logits are kept in the output; defaults to
            ``config.output_auxiliary_logits``.
        output_attentions: Forwarded to the base model; defaults to ``config.output_attentions``.
        return_dict: Whether to return a `Mask2FormerForUniversalSegmentationOutput` instead of a plain
            tuple; defaults to ``config.use_return_dict``.

    Returns:
        `Mask2FormerForUniversalSegmentationOutput`, or with ``return_dict=False`` a tuple of its
        non-``None`` values, which starts with the loss when one was computed.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # The class head below always iterates over the decoder's intermediate states, which
    # the base model only exposes when hidden states are requested. Always request them
    # here (previously a TypeError when neither hidden states nor the auxiliary loss were
    # enabled). What is returned to the caller is still gated by `output_hidden_states`.
    outputs = self.model.Mask2FormerModel_forward_second_part(
        query_features=query_features,
        query_embeddings=query_embeddings,
        pixel_level_module_output=pixel_level_module_output,
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        output_hidden_states=True,
        output_attentions=output_attentions,
        return_dict=True,
    )

    loss, loss_dict, auxiliary_logits = None, None, None

    # One open-vocabulary class prediction per intermediate decoder state.
    class_queries_logits = tuple(
        self.ov_class_predictor(decoder_output.transpose(0, 1), text_classifier)
        for decoder_output in outputs.transformer_decoder_intermediate_states
    )

    masks_queries_logits = outputs.masks_queries_logits

    auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits)

    if mask_labels is not None and class_labels is not None:
        loss_dict = self.get_loss_dict(
            masks_queries_logits=masks_queries_logits[-1],
            class_queries_logits=class_queries_logits[-1],
            mask_labels=mask_labels,
            class_labels=class_labels,
            auxiliary_predictions=auxiliary_logits,
        )
        loss = self.get_loss(loss_dict)

    encoder_hidden_states = None
    pixel_decoder_hidden_states = None
    transformer_decoder_hidden_states = None

    if output_hidden_states:
        encoder_hidden_states = outputs.encoder_hidden_states
        pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states
        transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states

    output_auxiliary_logits = (
        self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits
    )
    if not output_auxiliary_logits:
        auxiliary_logits = None

    output = Mask2FormerForUniversalSegmentationOutput(
        loss=loss,
        class_queries_logits=class_queries_logits[-1],
        masks_queries_logits=masks_queries_logits[-1],
        auxiliary_logits=auxiliary_logits,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state,
        transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state,
        encoder_hidden_states=encoder_hidden_states,
        pixel_decoder_hidden_states=pixel_decoder_hidden_states,
        transformer_decoder_hidden_states=transformer_decoder_hidden_states,
        attentions=outputs.attentions,
    )

    if not return_dict:
        # `loss` is the first field of `output`, so the tuple already starts with it when
        # it is set. The former `(loss) + output` added a Tensor to a tuple and raised.
        output = tuple(v for v in output.values() if v is not None)

    return output