Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init | |
from mmcv.cnn.bricks.transformer import (build_positional_encoding, | |
build_transformer_layer_sequence) | |
from mmcv.ops import point_sample | |
from mmcv.runner import ModuleList | |
from mmdet.core import build_assigner, build_sampler, reduce_mean | |
from mmdet.models.utils import get_uncertain_point_coords_with_randomness | |
from ..builder import HEADS, build_loss | |
from .anchor_free_head import AnchorFreeHead | |
from .maskformer_head import MaskFormerHead | |
@HEADS.register_module()
class Mask2FormerHead(MaskFormerHead):
    """Implements the Mask2Former head.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things.
        num_stuff_classes (int): Number of stuff.
        num_queries (int): Number of query in Transformer decoder.
        num_transformer_feat_level (int): Number of multi-scale pixel
            decoder outputs consumed by the transformer decoder. Must equal
            ``num_levels`` of the pixel decoder encoder's attention config
            (checked by an assertion). Defaults to 3.
        pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
            decoder. Defaults to None.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of transformer encoder in
            pixel decoder to the embed_dim of transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder. Defaults to None.
        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder position encoding. Defaults to None.
        loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
            loss. Must provide ``class_weight``. Defaults to None.
        loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
            Defaults to None.
        loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
            Defaults to None.
        train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
            Mask2Former head.
        test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
            Mask2Former head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 num_transformer_feat_level=3,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=None,
                 loss_mask=None,
                 loss_dice=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        # Deliberately bypass MaskFormerHead.__init__ and
        # AnchorFreeHead.__init__: resolution starts at the class after
        # AnchorFreeHead in the MRO. Presumably this avoids building the
        # parent heads' own layers; all layers are built below instead.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.transformerlayers.\
            attn_cfgs.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.transformerlayers.\
            attn_cfgs.num_levels == num_transformer_feat_level
        # Copy so the caller's config is not mutated by update().
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if (self.decoder_embed_dims != feat_channels
                    or enforce_decoder_input_project):
                self.decoder_input_projs.append(
                    Conv2d(
                        feat_channels, self.decoder_embed_dims, kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = build_positional_encoding(
            positional_encoding)

        # Everything that lives in the transformer decoder's space (queries,
        # level embeddings added to projected decoder inputs, and the
        # prediction heads applied to decoder outputs) must be sized by
        # ``decoder_embed_dims``. Sizing them by ``feat_channels`` breaks as
        # soon as the input projections change the channel count. This is
        # identical to the previous behaviour when the two are equal.
        self.query_embed = nn.Embedding(self.num_queries,
                                        self.decoder_embed_dims)
        self.query_feat = nn.Embedding(self.num_queries,
                                       self.decoder_embed_dims)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        self.decoder_embed_dims)

        # +1 output for the background ("no object") class.
        self.cls_embed = nn.Linear(self.decoder_embed_dims,
                                   self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(self.decoder_embed_dims, self.decoder_embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.decoder_embed_dims, self.decoder_embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.decoder_embed_dims, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            self.sampler = build_sampler(self.train_cfg.sampler, context=self)
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)

    def init_weights(self):
        """Initialize weights of the head.

        The 1x1 decoder input projection convs get caffe2 Xavier init with
        zero bias. The pixel decoder uses its own ``init_weights``. Every
        transformer decoder parameter with more than one dimension gets
        Xavier-normal init.
        """
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
                           img_metas):
        """Compute classification and mask targets for one image.

        Matching is done on a shared set of ``self.num_points`` uniformly
        random points rather than on full-resolution masks.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_labels (Tensor): Ground truth class indices for one image with
                shape (num_gts, ).
            gt_masks (Tensor): Ground truth mask for each image, each with
                shape (num_gts, h, w).
            img_metas (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. Unmatched queries \
                    are labelled ``self.num_classes`` (background). \
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image, all \
                    ones. shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of the matched \
                    queries. shape (num_pos, h, w).
                - mask_weights (Tensor): Mask weights of each image, 1 for \
                    matched queries and 0 otherwise. shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each \
                    image.
                - neg_inds (Tensor): Sampled negative indices for each \
                    image.
        """
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        # One set of random points, shared by predictions and ground truth.
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        # assign and sample
        assign_result = self.assigner.assign(cls_score, mask_points_pred,
                                             gt_labels, gt_points_masks,
                                             img_metas)
        sampling_result = self.sampler.sample(assign_result, mask_pred,
                                              gt_masks)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target: background by default, matched gt class otherwise
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries, ))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds)

    def loss_single(self, cls_scores, mask_preds, gt_labels_list,
                    gt_masks_list, img_metas):
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should includes
                background.
            mask_preds (Tensor): Mask logits for a pixel decoder for all
                images. Shape (batch_size, num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image, each with shape (num_gts, ).
            gt_masks_list (list[Tensor]): Ground truth mask for each image,
                each with shape (num_gts, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components ``(loss_cls, loss_mask,
            loss_dice)`` for outputs from a single decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         num_total_pos,
         num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
                                           gt_labels_list, gt_masks_list,
                                           img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        # Normalise by the summed class weight of the targets.
        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        # Average number of positives across processes, clamped to >= 1.
        num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match: return zero-valued losses that stay in the graph
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        with torch.no_grad():
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_preds.unsqueeze(1), None, self.num_points,
                self.oversample_ratio, self.importance_sample_ratio)
            # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
            mask_point_targets = point_sample(
                mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
        # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
        mask_point_preds = point_sample(
            mask_preds.unsqueeze(1), points_coords).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_total_gts, num_points)
        # -> (num_total_gts * num_points, )
        mask_point_preds = mask_point_preds.reshape(-1)
        # shape (num_total_gts, num_points)
        # -> (num_total_gts * num_points, )
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(
            mask_point_preds,
            mask_point_targets,
            avg_factor=num_total_masks * self.num_points)

        return loss_cls, loss_mask, loss_dice

    def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
        """Forward for head part which is called after every decoder layer.

        Args:
            decoder_out (Tensor): in shape (num_queries, batch_size, c).
            mask_feature (Tensor): in shape (batch_size, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple contain three elements.

                - cls_pred (Tensor): Classification scores in shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should includes background.
                - mask_pred (Tensor): Mask scores in shape \
                    (batch_size, num_queries, h, w).
                - attn_mask (Tensor): Boolean attention mask in shape \
                    (batch_size * num_heads, num_queries, h*w), where \
                    ``h, w`` is ``attn_mask_target_size``. An entry is \
                    True where the predicted mask probability is below \
                    0.5. It is detached from the graph.
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        # shape (num_queries, batch_size, c) -> (batch_size, num_queries, c)
        decoder_out = decoder_out.transpose(0, 1)
        # shape (batch_size, num_queries, num_classes + 1)
        cls_pred = self.cls_embed(decoder_out)
        # shape (batch_size, num_queries, c)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, h, w)
        mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
        attn_mask = F.interpolate(
            mask_pred,
            attn_mask_target_size,
            mode='bilinear',
            align_corners=False)
        # shape (batch_size, num_queries, h, w) ->
        #   (batch_size * num_head, num_queries, h*w)
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
            (1, self.num_heads, 1, 1)).flatten(0, 1)
        # True = low foreground probability; with PyTorch's boolean
        # attn_mask convention these keys are excluded from attention.
        attn_mask = attn_mask.sigmoid() < 0.5
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (list[Tensor]): Multi scale Features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple: A tuple contains two elements.

                - cls_pred_list (list[Tensor)]: Classification logits \
                    for each decoder layer (plus one entry for the initial \
                    queries). Each is a 3D-tensor with shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should includes background.
                - mask_pred_list (list[Tensor]): Mask logits for each \
                    decoder layer (plus one entry for the initial queries). \
                    Each with shape (batch_size, num_queries, h, w).
        """
        batch_size = len(img_metas)
        mask_features, multi_scale_memorys = self.pixel_decoder(feats)
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
            decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
            level_embed = self.level_embed.weight[i].view(1, 1, -1)
            decoder_input = decoder_input + level_embed
            # All-False padding mask of shape (batch_size, h, w); it only
            # drives the positional encoding (no padded region is masked).
            mask = decoder_input.new_zeros(
                (batch_size, ) + multi_scale_memorys[i].shape[-2:],
                dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(
                mask)
            # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
            decoder_positional_encoding = decoder_positional_encoding.flatten(
                2).permute(2, 0, 1)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)
        # shape (num_queries, c) -> (num_queries, batch_size, c)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(
            (1, batch_size, 1))
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(
            (1, batch_size, 1))

        cls_pred_list = []
        mask_pred_list = []
        # Predictions from the learnable queries before any decoder layer;
        # they also provide the attention mask for the first layer.
        cls_pred, mask_pred, attn_mask = self.forward_head(
            query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

        for i in range(self.num_transformer_decoder_layers):
            # Layers cycle through the feature levels round-robin.
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True(all background), then set it all False,
            # presumably so no query row is fully masked in attention.
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            attn_masks = [attn_mask, None]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                attn_masks=attn_masks,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None)
            # The next layer's attention mask is sized for the next level.
            cls_pred, mask_pred, attn_mask = self.forward_head(
                query_feat, mask_features, multi_scale_memorys[
                    (i + 1) % self.num_transformer_feat_level].shape[-2:])

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list