|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from typing import Dict, List, Optional, Tuple, Union |
|
from transformers.models.mask2former.modeling_mask2former import ( |
|
Mask2FormerMaskedAttentionDecoderOutput, Mask2FormerModelOutput, |
|
Mask2FormerForUniversalSegmentationOutput, Mask2FormerMLPPredictionHead, |
|
sample_point, pair_wise_sigmoid_cross_entropy_loss, pair_wise_dice_loss, |
|
sigmoid_cross_entropy_loss, dice_loss) |
|
from torch import Tensor |
|
import torch.nn.functional as F |
|
|
|
from transformers.file_utils import is_scipy_available |
|
|
|
if is_scipy_available(): |
|
from scipy.optimize import linear_sum_assignment |
|
|
|
|
|
def get_classification_logits(x, text_classifier, logit_scale):
    """
    CLIP-style classification: L2-normalize the query embeddings `x` and the text embeddings
    `text_classifier`, then score them with a scaled dot product; the learnable temperature is
    exponentiated and clamped at 100, as in CLIP.
    """
    x = F.normalize(x, dim=-1)
|
text_classifier = F.normalize(text_classifier, dim=-1) |
|
logit_scale = torch.clamp(logit_scale.exp(), max=100) |
|
pred_logits = logit_scale * x @ text_classifier.T |
|
return pred_logits |
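
# Minimal usage sketch (illustrative only; the shapes below are assumptions, not fixed by this module):
#
#     queries = torch.randn(2, 100, 256)        # (batch_size, num_queries, hidden_dim)
#     text_classifier = torch.randn(151, 256)   # (num_text_classes, hidden_dim)
#     logit_scale = torch.ones([]) * np.log(1 / 0.07)
#     logits = get_classification_logits(queries, text_classifier, logit_scale)
#     # logits.shape == (2, 100, 151)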
|
|
|
|
|
def _post_init(self): |
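    """
    Meant to be bound to the segmentation model and called once after `__init__` (an assumption based
    on how `ov_class_predictor` uses these attributes): adds a query-projection MLP (`class_embed`)
    and a CLIP-style learnable temperature (`logit_scale`) initialized to log(1 / 0.07).
    """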
|
self.class_embed = Mask2FormerMLPPredictionHead(self.config.hidden_dim, self.config.hidden_dim, self.config.hidden_dim, 3) |
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) |
|
|
|
|
|
def ov_class_predictor(self, x, text_classifier): |
|
x = self.class_embed(x) |
|
all_pred_logits = [] |
|
for per_x, per_text_classifier in zip(x, text_classifier): |
|
per_pred_logits = get_classification_logits(per_x.unsqueeze(0), per_text_classifier, self.logit_scale) |
|
all_pred_logits.append(per_pred_logits.squeeze(0)) |
|
|
|
return all_pred_logits |
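
# Sketch of the expected inputs and outputs (assumed shapes): `x` holds the decoder query embeddings
# and `text_classifier` is a per-image list of text-embedding matrices, so each image may carry a
# different vocabulary size. Assuming `_post_init` and `ov_class_predictor` are bound to the model:
#
#     x = torch.randn(2, 100, 256)                                     # (batch_size, num_queries, hidden_dim)
#     text_classifier = [torch.randn(151, 256), torch.randn(81, 256)]  # one matrix per image
#     logits = model.ov_class_predictor(x, text_classifier)
#     # logits[0].shape == (100, 151), logits[1].shape == (100, 81)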
|
|
|
|
|
|
|
def Mask2FormerLoss_loss_labels( |
|
    self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.ndarray]
|
) -> Dict[str, Tensor]: |
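    """
    Cross-entropy classification loss, computed per image because in the open-vocabulary setting each
    image can carry a different number of candidate classes; the last channel is the per-image
    "no object" class, down-weighted by `eos_coef`. `indices[i]` is the `(query_idx, label_idx)` pair
    produced by the Hungarian matcher: query `query_idx[k]` is supervised with class
    `class_labels[i][label_idx[k]]`, and every unmatched query is supervised with "no object".
    """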
|
batch_size = len(class_queries_logits) |
|
num_queries = class_queries_logits[0].shape[0] |
|
all_ce_loss = [] |
|
for i in range(batch_size): |
|
num_labels_plus1 = class_queries_logits[i].shape[-1] |
|
empty_weight = torch.ones(num_labels_plus1) |
|
empty_weight[-1] = self.eos_coef |
|
empty_weight = empty_weight.to(class_queries_logits[i].device).to(class_queries_logits[i].dtype) |
|
criterion = nn.CrossEntropyLoss(weight=empty_weight, reduction='none') |
|
target_classes_o = class_labels[i][indices[i][1]] |
|
target_classes = torch.full( |
|
(num_queries, ), fill_value=num_labels_plus1-1, dtype=torch.int64, device=class_queries_logits[i].device) |
|
target_classes[indices[i][0]] = target_classes_o.to(class_queries_logits[i].device) |
|
target_classes = target_classes.unsqueeze(0) |
|
pred_logits = class_queries_logits[i].unsqueeze(0).transpose(1, 2) |
|
loss_ce = criterion(pred_logits, target_classes) |
|
all_ce_loss.append(loss_ce) |
|
losses = {"loss_cross_entropy": torch.cat(all_ce_loss, dim=-1).mean()} |
|
return losses |
|
|
|
def Mask2FormerLoss_loss_masks( |
|
self, |
|
masks_queries_logits: torch.Tensor, |
|
mask_labels: List[torch.Tensor], |
|
    indices: Tuple[np.ndarray],
|
num_masks: int |
|
) -> Dict[str, torch.Tensor]: |
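    """
    Point-sampled mask losses, as in the original Mask2Former loss: point coordinates are drawn with
    uncertainty-based importance sampling, then a sigmoid cross-entropy loss and a dice loss are
    computed on the sampled points. The bfloat16 casts when sampling the target points are kept from
    this file's original code (presumably for mixed-precision training).
    """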
|
src_idx = self._get_predictions_permutation_indices(indices) |
|
tgt_idx = self._get_targets_permutation_indices(indices) |
|
|
|
pred_masks = masks_queries_logits[src_idx] |
|
|
|
|
|
target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) |
|
target_masks = target_masks[tgt_idx] |
|
|
|
|
|
pred_masks = pred_masks[:, None] |
|
target_masks = target_masks[:, None] |
|
|
|
|
|
with torch.no_grad(): |
|
point_coordinates = self.sample_points_using_uncertainty( |
|
pred_masks, |
|
lambda logits: self.calculate_uncertainty(logits), |
|
self.num_points, |
|
self.oversample_ratio, |
|
self.importance_sample_ratio, |
|
) |
|
point_labels = sample_point(target_masks.to(torch.bfloat16), point_coordinates.to(torch.bfloat16), align_corners=False).squeeze(1) |
|
|
|
point_logits = sample_point(pred_masks, point_coordinates.to(pred_masks.dtype), align_corners=False).squeeze(1) |
|
|
|
losses = { |
|
"loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), |
|
"loss_dice": dice_loss(point_logits, point_labels, num_masks), |
|
} |
|
|
|
del pred_masks |
|
del target_masks |
|
return losses |
|
|
|
def Mask2FormerLoss_sample_points_using_uncertainty( |
|
self, |
|
logits: torch.Tensor, |
|
uncertainty_function, |
|
num_points: int, |
|
oversample_ratio: int, |
|
importance_sample_ratio: float, |
|
) -> torch.Tensor: |
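    """
    PointRend-style point selection: oversample `num_points * oversample_ratio` random points, score
    them with `uncertainty_function`, keep the `importance_sample_ratio` fraction with the highest
    uncertainty, and fill the remainder with uniformly random points. Returns coordinates of shape
    `(num_boxes, num_points, 2)` with values in `[0, 1]`.
    """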
|
|
|
num_boxes = logits.shape[0] |
|
num_points_sampled = int(num_points * oversample_ratio) |
|
|
|
|
|
point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device) |
|
|
|
point_logits = sample_point(logits, point_coordinates.to(logits.dtype), align_corners=False) |
|
|
|
point_uncertainties = uncertainty_function(point_logits) |
|
|
|
num_uncertain_points = int(importance_sample_ratio * num_points) |
|
num_random_points = num_points - num_uncertain_points |
|
|
|
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] |
|
shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) |
|
idx += shift[:, None] |
|
point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) |
|
|
|
if num_random_points > 0: |
|
point_coordinates = torch.cat( |
|
[point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)], |
|
dim=1, |
|
) |
|
return point_coordinates |
|
|
|
|
|
|
|
@torch.no_grad() |
|
def Mask2FormerHungarianMatcher_forward( |
|
self, |
|
masks_queries_logits: torch.Tensor, |
|
class_queries_logits: torch.Tensor, |
|
mask_labels: torch.Tensor, |
|
class_labels: torch.Tensor, |
|
) -> List[Tuple[Tensor, Tensor]]:
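    """
    Hungarian matching between predicted and ground-truth masks, structurally identical to the
    original Mask2Former matcher: the assignment cost combines the classification cost with
    point-sampled sigmoid cross-entropy and dice mask costs. Returns one `(pred_idx, target_idx)`
    index pair per image, consumed by the `Mask2FormerLoss_*` functions above.
    """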
|
    indices: List[Tuple[np.ndarray, np.ndarray]] = []
|
|
|
|
|
batch_size = masks_queries_logits.shape[0] |
|
for i in range(batch_size): |
|
pred_probs = class_queries_logits[i].softmax(-1) |
|
pred_mask = masks_queries_logits[i] |
|
|
|
|
|
cost_class = -pred_probs[:, class_labels[i]] |
|
target_mask = mask_labels[i].to(pred_mask) |
|
target_mask = target_mask[:, None] |
|
pred_mask = pred_mask[:, None] |
|
|
|
|
|
point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device) |
|
|
|
target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1).to(target_mask.dtype) |
|
target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1) |
|
|
|
pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1).to(pred_mask.dtype) |
|
pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1) |
|
|
|
|
|
cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask) |
|
|
|
cost_dice = pair_wise_dice_loss(pred_mask, target_mask) |
|
|
|
cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice |
|
|
|
cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10)) |
|
cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10)) |
|
cost_matrix = torch.nan_to_num(cost_matrix, 0) |
|
|
|
        assigned_indices: Tuple[np.ndarray, np.ndarray] = linear_sum_assignment(cost_matrix.to(torch.float32).cpu())
|
indices.append(assigned_indices) |
|
|
|
|
|
matched_indices = [ |
|
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices |
|
] |
|
return matched_indices |
|
|
|
|
|
|
|
|
|
def Mask2FormerMaskedAttentionDecoder_forward_first3layers( |
|
self, |
|
inputs_embeds: torch.Tensor = None, |
|
multi_stage_positional_embeddings: torch.Tensor = None, |
|
pixel_embeddings: torch.Tensor = None, |
|
encoder_hidden_states: torch.Tensor = None, |
|
query_position_embeddings: torch.Tensor = None, |
|
feature_size_list: List = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
): |
|
r""" |
|
Args: |
|
inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): |
|
The query embeddings that are passed into the decoder. |
|
multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`): |
|
Position embeddings that are added to the keys in each cross(masked)-attention layer. |
|
pixel_embeddings (`torch.FloatTensor`): |
|
Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel |
|
Decoder. |
|
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
            Position embeddings that are added to the queries and keys in each self-attention layer.
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`): |
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the |
|
cross(masked)-attention of the decoder. |
|
feature_size_list (`List[torch.Size]`): |
|
This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if inputs_embeds is not None: |
|
hidden_states = inputs_embeds |
|
|
|
|
|
intermediate = () |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
attentions = () if output_attentions else None |
|
|
|
|
|
intermediate_mask_predictions = () |
|
|
|
intermediate_hidden_states = self.layernorm(inputs_embeds) |
|
intermediate += (intermediate_hidden_states,) |
|
|
|
predicted_mask, attention_mask = self.mask_predictor( |
|
intermediate_hidden_states, pixel_embeddings, feature_size_list[0] |
|
) |
|
intermediate_mask_predictions += (predicted_mask,) |
|
|
|
for idx, decoder_layer in enumerate(self.layers[:3]): |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
dropout_probability = torch.rand([]) |
|
|
|
if self.training and (dropout_probability < self.layerdrop): |
|
continue |
|
|
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
decoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states, |
|
None, |
|
None, |
|
output_attentions, |
|
) |
|
|
|
else: |
|
level_index = idx % self.num_feature_levels |
|
|
|
where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype) |
|
|
|
attention_mask = attention_mask * where.unsqueeze(-1) |
|
|
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
level_index=level_index, |
|
position_embeddings=multi_stage_positional_embeddings, |
|
query_position_embeddings=query_position_embeddings, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=attention_mask, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
intermediate_hidden_states = self.layernorm(layer_outputs[0]) |
|
|
|
predicted_mask, attention_mask = self.mask_predictor( |
|
intermediate_hidden_states, |
|
pixel_embeddings, |
|
feature_size_list[(idx + 1) % self.num_feature_levels], |
|
) |
|
|
|
intermediate_mask_predictions += (predicted_mask,) |
|
|
|
|
|
intermediate += (intermediate_hidden_states,) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if output_attentions: |
|
attentions += (layer_outputs[1],) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
hidden_states = hidden_states.transpose(1, 0) |
|
if not return_dict: |
|
outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions] |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return Mask2FormerMaskedAttentionDecoderOutput( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states, |
|
attentions=attentions, |
|
intermediate_hidden_states=intermediate, |
|
masks_queries_logits=intermediate_mask_predictions, |
|
) |
|
|
|
|
|
def Mask2FormerMaskedAttentionDecoder_forward_last3layers( |
|
self, |
|
inputs_embeds: torch.Tensor = None, |
|
multi_stage_positional_embeddings: torch.Tensor = None, |
|
pixel_embeddings: torch.Tensor = None, |
|
encoder_hidden_states: torch.Tensor = None, |
|
query_position_embeddings: torch.Tensor = None, |
|
feature_size_list: List = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
): |
|
r""" |
|
Args: |
|
inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): |
|
The query embeddings that are passed into the decoder. |
|
multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`): |
|
Position embeddings that are added to the keys in each cross(masked)-attention layer. |
|
pixel_embeddings (`torch.FloatTensor`): |
|
Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel |
|
Decoder. |
|
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
            Position embeddings that are added to the queries and keys in each self-attention layer.
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`): |
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the |
|
cross(masked)-attention of the decoder. |
|
feature_size_list (`List[torch.Size]`): |
|
This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if inputs_embeds is not None: |
|
hidden_states = inputs_embeds |
|
|
|
|
|
intermediate = () |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
attentions = () if output_attentions else None |
|
|
|
|
|
intermediate_mask_predictions = () |
|
|
|
intermediate_hidden_states = self.layernorm(inputs_embeds) |
|
intermediate += (intermediate_hidden_states,) |
|
|
|
predicted_mask, attention_mask = self.mask_predictor( |
|
intermediate_hidden_states, pixel_embeddings, feature_size_list[0] |
|
) |
|
intermediate_mask_predictions += (predicted_mask,) |
|
|
|
for _idx, decoder_layer in enumerate(self.layers[3:]): |
|
idx = _idx + 3 |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
dropout_probability = torch.rand([]) |
|
|
|
if self.training and (dropout_probability < self.layerdrop): |
|
continue |
|
|
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
decoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states, |
|
None, |
|
None, |
|
output_attentions, |
|
) |
|
|
|
else: |
|
level_index = idx % self.num_feature_levels |
|
|
|
where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype) |
|
|
|
attention_mask = attention_mask * where.unsqueeze(-1) |
|
|
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
level_index=level_index, |
|
position_embeddings=multi_stage_positional_embeddings, |
|
query_position_embeddings=query_position_embeddings, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=attention_mask, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
intermediate_hidden_states = self.layernorm(layer_outputs[0]) |
|
|
|
predicted_mask, attention_mask = self.mask_predictor( |
|
intermediate_hidden_states, |
|
pixel_embeddings, |
|
feature_size_list[(idx + 1) % self.num_feature_levels], |
|
) |
|
|
|
intermediate_mask_predictions += (predicted_mask,) |
|
|
|
|
|
intermediate += (intermediate_hidden_states,) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if output_attentions: |
|
attentions += (layer_outputs[1],) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
hidden_states = hidden_states.transpose(1, 0) |
|
if not return_dict: |
|
outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions] |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
return Mask2FormerMaskedAttentionDecoderOutput( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states, |
|
attentions=attentions, |
|
intermediate_hidden_states=intermediate, |
|
masks_queries_logits=intermediate_mask_predictions, |
|
) |
|
|
|
|
|
def Mask2FormerTransformerModule_forward_first_part( |
|
self, |
|
multi_scale_features: List[Tensor], |
|
mask_features: Tensor, |
|
output_hidden_states: bool = False, |
|
output_attentions: bool = False, |
|
) -> Mask2FormerMaskedAttentionDecoderOutput: |
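    """
    Same preamble as the original `Mask2FormerTransformerModule.forward` (multi-scale projections,
    level embeddings, query initialization), but only the first three masked-attention decoder layers
    are run; the decoder output is returned so the caller can work on the queries before
    `Mask2FormerTransformerModule_forward_second_part` finishes the remaining layers.
    """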
|
multi_stage_features = [] |
|
multi_stage_positional_embeddings = [] |
|
size_list = [] |
|
|
|
for i in range(self.num_feature_levels): |
|
size_list.append(multi_scale_features[i].shape[-2:]) |
|
multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2)) |
|
multi_stage_features.append( |
|
self.input_projections[i](multi_scale_features[i]).flatten(2) |
|
+ self.level_embed.weight[i][None, :, None] |
|
) |
|
|
|
|
|
multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) |
|
multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) |
|
|
|
_, batch_size, _ = multi_stage_features[0].shape |
|
|
|
|
|
query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1) |
|
query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1) |
|
|
|
decoder_output = self.decoder.Mask2FormerMaskedAttentionDecoder_forward_first3layers( |
|
inputs_embeds=query_features, |
|
multi_stage_positional_embeddings=multi_stage_positional_embeddings, |
|
pixel_embeddings=mask_features, |
|
encoder_hidden_states=multi_stage_features, |
|
query_position_embeddings=query_embeddings, |
|
feature_size_list=size_list, |
|
output_hidden_states=output_hidden_states, |
|
output_attentions=output_attentions, |
|
return_dict=True, |
|
) |
|
|
|
return decoder_output |
|
|
|
|
|
def Mask2FormerTransformerModule_forward_second_part( |
|
self, |
|
query_features: Tensor, |
|
query_embeddings: Tensor, |
|
multi_scale_features: List[Tensor], |
|
mask_features: Tensor, |
|
output_hidden_states: bool = False, |
|
output_attentions: bool = False, |
|
) -> Mask2FormerMaskedAttentionDecoderOutput: |
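    """
    Counterpart of `Mask2FormerTransformerModule_forward_first_part`: rebuilds the multi-scale
    features and positional embeddings, then runs the last three masked-attention decoder layers on
    the externally supplied `query_features` and `query_embeddings`.
    """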
|
multi_stage_features = [] |
|
multi_stage_positional_embeddings = [] |
|
size_list = [] |
|
|
|
for i in range(self.num_feature_levels): |
|
size_list.append(multi_scale_features[i].shape[-2:]) |
|
multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2)) |
|
multi_stage_features.append( |
|
self.input_projections[i](multi_scale_features[i]).flatten(2) |
|
+ self.level_embed.weight[i][None, :, None] |
|
) |
|
|
|
|
|
multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) |
|
multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) |
|
|
|
_, batch_size, _ = multi_stage_features[0].shape |
|
|
|
|
|
|
|
|
|
|
|
decoder_output = self.decoder.Mask2FormerMaskedAttentionDecoder_forward_last3layers( |
|
inputs_embeds=query_features, |
|
multi_stage_positional_embeddings=multi_stage_positional_embeddings, |
|
pixel_embeddings=mask_features, |
|
encoder_hidden_states=multi_stage_features, |
|
query_position_embeddings=query_embeddings, |
|
feature_size_list=size_list, |
|
output_hidden_states=output_hidden_states, |
|
output_attentions=output_attentions, |
|
return_dict=True, |
|
) |
|
|
|
return decoder_output |
|
|
|
|
|
def Mask2FormerModel_forward_first_part( |
|
self, |
|
pixel_values: Tensor, |
|
pixel_mask: Optional[Tensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
):
    r"""
    Returns:
        `(query_features, pixel_level_module_output)`: the transformer decoder query features produced
        by the first three decoder layers, together with the pixel level module output; both are meant
        to be passed on to `Mask2FormerModel_forward_second_part`.
|
|
|
Examples: |
|
```python |
|
>>> import torch |
|
>>> from PIL import Image |
|
>>> import requests |
|
>>> from transformers import AutoImageProcessor, Mask2FormerModel |
|
|
|
>>> # load image |
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
>>> # load image preprocessor and Mask2FormerModel trained on COCO instance segmentation dataset |
|
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance") |
|
>>> model = Mask2FormerModel.from_pretrained("facebook/mask2former-swin-small-coco-instance") |
|
>>> inputs = image_processor(image, return_tensors="pt") |
|
|
|
>>> # forward pass |
|
>>> with torch.no_grad(): |
|
... outputs = model(**inputs) |
|
|
|
>>> # model outputs last hidden states of shape (batch_size, num_queries, hidden_size) |
|
>>> print(outputs.transformer_decoder_last_hidden_state.shape) |
|
torch.Size([1, 100, 256]) |
|
``` |
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
batch_size, _, height, width = pixel_values.shape |
|
|
|
if pixel_mask is None: |
|
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) |
|
|
|
pixel_level_module_output = self.pixel_level_module( |
|
pixel_values=pixel_values, output_hidden_states=output_hidden_states |
|
) |
|
|
|
transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_first_part( |
|
multi_scale_features=pixel_level_module_output.decoder_hidden_states, |
|
mask_features=pixel_level_module_output.decoder_last_hidden_state, |
|
output_hidden_states=True, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
query_features = transformer_module_output.last_hidden_state |
|
return query_features, pixel_level_module_output |
|
|
|
|
|
def Mask2FormerModel_forward_second_part( |
|
self, |
|
query_features: Tensor, |
|
query_embeddings: Tensor, |
|
pixel_level_module_output, |
|
pixel_values: Tensor, |
|
pixel_mask: Optional[Tensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Mask2FormerModelOutput: |
|
r""" |
|
Returns: |
|
`Mask2FormerModelOutput` |
|
|
|
Examples: |
|
```python |
|
>>> import torch |
|
>>> from PIL import Image |
|
>>> import requests |
|
>>> from transformers import AutoImageProcessor, Mask2FormerModel |
|
|
|
>>> # load image |
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
>>> # load image preprocessor and Mask2FormerModel trained on COCO instance segmentation dataset |
|
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance") |
|
>>> model = Mask2FormerModel.from_pretrained("facebook/mask2former-swin-small-coco-instance") |
|
>>> inputs = image_processor(image, return_tensors="pt") |
|
|
|
>>> # forward pass |
|
>>> with torch.no_grad(): |
|
... outputs = model(**inputs) |
|
|
|
>>> # model outputs last hidden states of shape (batch_size, num_queries, hidden_size) |
|
>>> print(outputs.transformer_decoder_last_hidden_state.shape) |
|
torch.Size([1, 100, 256]) |
|
``` |
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
batch_size, _, height, width = pixel_values.shape |
|
|
|
if pixel_mask is None: |
|
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) |
|
|
|
transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_second_part( |
|
query_features=query_features, |
|
query_embeddings=query_embeddings, |
|
multi_scale_features=pixel_level_module_output.decoder_hidden_states, |
|
mask_features=pixel_level_module_output.decoder_last_hidden_state, |
|
output_hidden_states=True, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
encoder_hidden_states = None |
|
pixel_decoder_hidden_states = None |
|
transformer_decoder_hidden_states = None |
|
transformer_decoder_intermediate_states = None |
|
|
|
if output_hidden_states: |
|
encoder_hidden_states = pixel_level_module_output.encoder_hidden_states |
|
pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states |
|
transformer_decoder_hidden_states = transformer_module_output.hidden_states |
|
transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states |
|
|
|
output = Mask2FormerModelOutput( |
|
encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state, |
|
pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state, |
|
transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state, |
|
encoder_hidden_states=encoder_hidden_states, |
|
pixel_decoder_hidden_states=pixel_decoder_hidden_states, |
|
transformer_decoder_hidden_states=transformer_decoder_hidden_states, |
|
transformer_decoder_intermediate_states=transformer_decoder_intermediate_states, |
|
attentions=transformer_module_output.attentions, |
|
masks_queries_logits=transformer_module_output.masks_queries_logits, |
|
) |
|
|
|
if not return_dict: |
|
output = tuple(v for v in output.values() if v is not None) |
|
|
|
return output |
|
|
|
|
|
def Mask2FormerForUniversalSegmentation_forward_first_part( |
|
self, |
|
pixel_values: Tensor, |
|
mask_labels: Optional[List[Tensor]] = None, |
|
class_labels: Optional[List[Tensor]] = None, |
|
pixel_mask: Optional[Tensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_auxiliary_logits: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
):
    # Returns `(query_features, pixel_level_module_output)` from the first three decoder layers; both
    # feed `Mask2FormerForUniversalSegmentation_forward_second_part` below.
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
query_features, pixel_level_module_output = self.model.Mask2FormerModel_forward_first_part( |
|
pixel_values=pixel_values, |
|
pixel_mask=pixel_mask, |
|
output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, |
|
output_attentions=output_attentions, |
|
return_dict=True, |
|
) |
|
|
|
return query_features, pixel_level_module_output |
|
|
|
|
|
def Mask2FormerForUniversalSegmentation_forward_second_part( |
|
self, |
|
query_features, |
|
query_embeddings, |
|
pixel_level_module_output, |
|
text_classifier, |
|
pixel_values: Tensor, |
|
mask_labels: Optional[List[Tensor]] = None, |
|
class_labels: Optional[List[Tensor]] = None, |
|
pixel_mask: Optional[Tensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_auxiliary_logits: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Mask2FormerForUniversalSegmentationOutput: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.model.Mask2FormerModel_forward_second_part( |
|
query_features=query_features, |
|
query_embeddings=query_embeddings, |
|
pixel_level_module_output=pixel_level_module_output, |
|
pixel_values=pixel_values, |
|
pixel_mask=pixel_mask, |
|
output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, |
|
output_attentions=output_attentions, |
|
return_dict=True, |
|
) |
|
|
|
loss, loss_dict, auxiliary_logits = None, None, None |
|
class_queries_logits = () |
|
|
|
for decoder_output in outputs.transformer_decoder_intermediate_states: |
|
class_prediction = self.ov_class_predictor(decoder_output.transpose(0, 1), text_classifier) |
|
|
|
class_queries_logits += (class_prediction,) |
|
|
|
masks_queries_logits = outputs.masks_queries_logits |
|
|
|
auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits) |
|
|
|
if mask_labels is not None and class_labels is not None: |
|
loss_dict = self.get_loss_dict( |
|
masks_queries_logits=masks_queries_logits[-1], |
|
class_queries_logits=class_queries_logits[-1], |
|
mask_labels=mask_labels, |
|
class_labels=class_labels, |
|
auxiliary_predictions=auxiliary_logits, |
|
) |
|
loss = self.get_loss(loss_dict) |
|
|
|
encoder_hidden_states = None |
|
pixel_decoder_hidden_states = None |
|
transformer_decoder_hidden_states = None |
|
|
|
if output_hidden_states: |
|
encoder_hidden_states = outputs.encoder_hidden_states |
|
pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states |
|
transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states |
|
|
|
output_auxiliary_logits = ( |
|
self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits |
|
) |
|
if not output_auxiliary_logits: |
|
auxiliary_logits = None |
|
|
|
output = Mask2FormerForUniversalSegmentationOutput( |
|
loss=loss, |
|
class_queries_logits=class_queries_logits[-1], |
|
masks_queries_logits=masks_queries_logits[-1], |
|
auxiliary_logits=auxiliary_logits, |
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, |
|
pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state, |
|
transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state, |
|
encoder_hidden_states=encoder_hidden_states, |
|
pixel_decoder_hidden_states=pixel_decoder_hidden_states, |
|
transformer_decoder_hidden_states=transformer_decoder_hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
if not return_dict: |
|
output = tuple(v for v in output.values() if v is not None) |
|
if loss is not None: |
|
            output = (loss,) + output
|
return output |
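
# How the pieces above fit together (an illustrative sketch, not the canonical integration): the free
# functions in this file are written as unbound methods, so they presumably get attached to the
# corresponding Hugging Face classes under these exact attribute names so that the internal
# `self....` calls resolve, e.g.:
#
#     from transformers.models.mask2former import modeling_mask2former as m2f
#
#     m2f.Mask2FormerMaskedAttentionDecoder.Mask2FormerMaskedAttentionDecoder_forward_first3layers = (
#         Mask2FormerMaskedAttentionDecoder_forward_first3layers
#     )
#     m2f.Mask2FormerMaskedAttentionDecoder.Mask2FormerMaskedAttentionDecoder_forward_last3layers = (
#         Mask2FormerMaskedAttentionDecoder_forward_last3layers
#     )
#     m2f.Mask2FormerTransformerModule.Mask2FormerTransformerModule_forward_first_part = (
#         Mask2FormerTransformerModule_forward_first_part
#     )
#     m2f.Mask2FormerTransformerModule.Mask2FormerTransformerModule_forward_second_part = (
#         Mask2FormerTransformerModule_forward_second_part
#     )
#     m2f.Mask2FormerModel.Mask2FormerModel_forward_first_part = Mask2FormerModel_forward_first_part
#     m2f.Mask2FormerModel.Mask2FormerModel_forward_second_part = Mask2FormerModel_forward_second_part
#     # ... and likewise for `_post_init`, `ov_class_predictor`, the `Mask2FormerLoss_*` functions,
#     # `Mask2FormerHungarianMatcher_forward`, and the two
#     # `Mask2FormerForUniversalSegmentation_forward_*` functions.
#
# A forward pass would then run the first three decoder layers, do whatever per-query processing the
# surrounding project needs in between, and finish with the last three layers plus the
# open-vocabulary class head (the names below are placeholders supplied by the caller):
#
#     query_features, pixel_out = seg_model.Mask2FormerForUniversalSegmentation_forward_first_part(
#         pixel_values=pixel_values
#     )
#     outputs = seg_model.Mask2FormerForUniversalSegmentation_forward_second_part(
#         query_features=query_features.transpose(0, 1),  # back to (num_queries, batch_size, hidden_dim)
#         query_embeddings=query_embeddings,              # e.g. the transformer module's queries_embedder weights
#         pixel_level_module_output=pixel_out,
#         text_classifier=text_classifier,                # per-image list of text embeddings
#         pixel_values=pixel_values,
#     )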
|
|