# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
import logging
import fvcore.nn.weight_init as weight_init
from typing import Optional
import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d
from detectron2.utils.registry import Registry

from .position_encoding import PositionEmbeddingSine


TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
TRANSFORMER_DECODER_REGISTRY.__doc__ = """
Registry for transformer module in MaskFormer.
"""


def build_transformer_decoder(cfg, in_channels, mask_classification=True):
    """
    Build a transformer decoder from `cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME`.

    The name is looked up in ``TRANSFORMER_DECODER_REGISTRY`` and the registered
    class is called with ``(cfg, in_channels, mask_classification)``.
    """
    name = cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME
    return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, in_channels, mask_classification)


def get_classification_logits(x, text_classifier, logit_scale, num_templates=None):
    """
    Compute scaled similarity logits between embeddings and text classifiers.

    Args:
        x: embeddings in shape of [B, *, C]; L2-normalized here along the last dim.
            The template ensembling below indexes three dims, i.e. [B, Q, C].
        text_classifier: in shape of [num_classes, C]. It is used as given
            (it is NOT normalized in this function).
        logit_scale: a learnable scalar, see
            https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/model.py#L201
            Its exponential, clamped to at most 100, multiplies the logits.
        num_templates: iterable with the number of consecutive classifier rows
            (prompt templates) belonging to each class, excluding the last row.
            When None, every classifier row is treated as its own class and the
            logits are returned without ensembling.

    Returns:
        Logits in shape of [B, *, num_classes].
    """
    x = F.normalize(x, dim=-1)
    logit_scale = torch.clamp(logit_scale.exp(), max=100)
    pred_logits = logit_scale * x @ text_classifier.T  # B, *, N + 1

    # The default used to crash when iterating ``None``. Returning the raw
    # logits is what a per-class template count of 1 would produce.
    if num_templates is None:
        return pred_logits

    # max ensemble as in OpenSeg/ODISE
    final_pred_logits = []
    cur_idx = 0
    for num_t in num_templates:
        final_pred_logits.append(pred_logits[:, :, cur_idx: cur_idx + num_t].max(-1).values)
        cur_idx += num_t
    final_pred_logits.append(pred_logits[:, :, -1])  # the last classifier is for void
    final_pred_logits = torch.stack(final_pred_logits, dim=-1)

    return final_pred_logits


# Ref: https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923
class MaskPooling(nn.Module):
    """Average a feature map over the positive region of each mask (no parameters)."""

    def __init__(
        self,
    ):
        super().__init__()

    def forward(self, x, mask):
        """
        Args:
            x: [B, C, H, W]
            mask: [B, Q, H, W]

        Returns:
            [B, Q, C] features, one mask-averaged vector per mask.
        """
        if not x.shape[-2:] == mask.shape[-2:]:
            # reshape mask to x
            mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
        with torch.no_grad():
            # The mask only selects pixels; no gradient flows through it.
            mask = mask.detach()
            mask = (mask > 0).to(mask.dtype)
            # The epsilon keeps an empty mask from dividing by zero.
            denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8

        mask_pooled_x = torch.einsum(
            "bchw,bqhw->bqc",
            x,
            mask / denorm,
        )
        return mask_pooled_x


class SelfAttentionLayer(nn.Module):
    """
    Multi-head self-attention with a residual connection and LayerNorm.

    ``normalize_before`` selects pre-norm (norm, attention, residual) instead of
    post-norm (attention, residual, norm). ``nn.MultiheadAttention`` is built
    without ``batch_first``, so inputs are sequence-first.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Stored but not used by any forward path of this layer.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; 1-D params keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm path: attention, residual add, then LayerNorm."""
        # The positional embedding goes to query and key only, not to the value.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm path: LayerNorm, attention, then residual add."""
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm path according to ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)


class CrossAttentionLayer(nn.Module):
    """
    Multi-head cross-attention from ``tgt`` (queries) to ``memory`` (keys/values)
    with a residual connection and LayerNorm.

    ``normalize_before`` selects pre-norm instead of post-norm.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Stored but not used by any forward path of this layer.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; 1-D params keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm path: attention, residual add, then LayerNorm."""
        # Positional embeddings are added to query and key; the value is raw memory.
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm path: LayerNorm, attention, then residual add."""
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm path according to ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


class kMaXCrossAttentionLayer(nn.Module):
    """
    Cross-attention variant with a value projection (``v_proj``).

    NOTE(review): this layer looks unfinished. ``__init__`` never creates
    ``self.multihead_attn``, yet both ``forward_post`` and ``forward_pre`` call
    it, so calling ``forward`` raises AttributeError at the latest when that
    attribute is reached. Nothing in this file instantiates the class. The
    intended algorithm cannot be determined from this file, so the code is
    left as found — confirm intent before using it.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.nhead = nhead
        self.v_proj = nn.Linear(d_model, d_model)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Stored but not used by any forward path of this layer.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; 1-D params keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm path. See the class-level NOTE(review): not functional as written."""
        # memory mask is in shape [B*num_heads, Q, HW]
        # NOTE(review): the two results below are computed but never used —
        # ``tgt2`` is overwritten by the attention call and ``clustering_result``
        # is never read. The ``view`` also uses ``tgt.shape[0]`` with a 3-dim
        # target shape, which does not obviously match the mask shape stated
        # above — verify before relying on it.
        tgt2 = self.v_proj(memory)  # HW x C
        clustering_result = memory_mask.view(tgt.shape[0], self.nhead, memory_mask.shape[-1])[:, 0]  # [B, Q, HW]
        # NOTE(review): ``self.multihead_attn`` is not defined in ``__init__``.
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm path. See the class-level NOTE(review): not functional as written."""
        tgt2 = self.norm(tgt)
        # NOTE(review): ``self.multihead_attn`` is not defined in ``__init__``.
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm path according to ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):
    """
    Two-layer feed-forward block (linear, activation, dropout, linear) with a
    residual connection and LayerNorm, in pre- or post-norm arrangement.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; 1-D params keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None.

        Not called by this layer's own forward paths.
        """
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        """Post-norm path: feed-forward, residual add, then LayerNorm."""
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        """Pre-norm path: LayerNorm, feed-forward, then residual add."""
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        """Dispatch to the pre- or post-norm path according to ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string.

    Accepts "relu", "gelu" or "glu"; raises RuntimeError for anything else.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name ("glu" was previously omitted).
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)

    ``num_layers`` linear layers with ReLU after every layer except the last.
    Hidden layers all have width ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # Pair consecutive widths: input -> hidden ... hidden -> output.
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


@TRANSFORMER_DECODER_REGISTRY.register()
class MultiScaleMaskedTransformerDecoder(nn.Module):
    """
    Masked-attention transformer decoder over three feature scales with a
    text-classifier based (FC-CLIP) classification head.

    Each decoder layer applies masked cross-attention to one feature level
    (cycling through the levels), then self-attention, then an FFN. Class and
    mask predictions are produced before the first layer and after every layer.
    """

    @configurable
    def __init__(
        self,
        in_channels,
        mask_classification=True,
        *,
        num_classes: int,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        enforce_input_project: bool,
        clip_embedding_dim: int
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
                (must be true; asserted below)
            num_classes: number of classes (accepted but not used here; the
                linear classifier that consumed it is commented out)
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
            clip_embedding_dim: output dimension of the class embedding MLP,
                i.e. the dimension in which it is compared with the text
                classifier
        """
        super().__init__()

        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                # Empty Sequential acts as an identity projection.
                self.input_proj.append(nn.Sequential())

        # output FFNs
        # if self.mask_classification:
        #     self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

        # FC-CLIP
        self.mask_pooling = MaskPooling()
        self._mask_pooling_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim))
        self.class_embed = MLP(hidden_dim, hidden_dim, clip_embedding_dim, 3)
        # Stored in log space; exp() of the initial value is 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        """Translate ``cfg`` into the keyword arguments of ``__init__``."""
        ret = {}
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        # Transformer parameters:
        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD

        # NOTE: because we add learnable query features which requires supervision,
        # we add minus 1 to decoder layers to be consistent with our loss
        # implementation: that is, number of auxiliary losses is always
        # equal to number of decoder layers. With learnable query features, the number of
        # auxiliary losses equals number of decoders plus 1.
        assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ

        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
        ret["clip_embedding_dim"] = cfg.MODEL.FC_CLIP.EMBED_DIM

        return ret

    def forward(self, x, mask_features, mask=None, text_classifier=None, num_templates=None):
        """
        Run the decoder and return final and auxiliary predictions.

        Args:
            x: list of ``num_feature_levels`` feature maps, each NxCxHxW.
            mask_features: per-pixel features combined with the mask embeddings
                to produce mask logits (used as "bchw" in the prediction heads).
            mask: ignored; it is deleted immediately.
            text_classifier: text embeddings forwarded to
                ``get_classification_logits``. Required in practice even though
                it defaults to None, because its ``.T`` is taken there.
            num_templates: per-class template counts forwarded to
                ``get_classification_logits``.

        Returns:
            dict with 'pred_logits' and 'pred_masks' from the last prediction
            step and 'aux_outputs' holding the earlier steps.
        """
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        # disable mask, it does not affect performance
        del mask

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        predictions_class = []
        predictions_mask = []

        # prediction heads on learnable query features
        outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0],
            text_classifier=text_classifier, num_templates=num_templates)
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # A row that is True everywhere would mask out every key; reset such
            # rows to all-False so the query attends to the whole feature map.
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index], query_pos=query_embed
            )

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )

            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

            # The attention mask is sized for the level the NEXT layer attends to.
            outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
                output, mask_features,
                attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels],
                text_classifier=text_classifier, num_templates=num_templates)
            predictions_class.append(outputs_class)
            predictions_mask.append(outputs_mask)

        assert len(predictions_class) == self.num_layers + 1

        out = {
            'pred_logits': predictions_class[-1],
            'pred_masks': predictions_mask[-1],
            'aux_outputs': self._set_aux_loss(
                predictions_class if self.mask_classification else None, predictions_mask
            )
        }
        return out

    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size,
                                 text_classifier, num_templates):
        """
        Produce class logits, mask logits and the next attention mask.

        Args:
            output: query features, sequence-first (QxNxC as built in ``forward``).
            mask_features: per-pixel features, used as "bchw".
            attn_mask_target_size: spatial size the attention mask is resized to.
            text_classifier, num_templates: forwarded to
                ``get_classification_logits``.

        Returns:
            (outputs_class, outputs_mask, attn_mask) where ``attn_mask`` is a
            detached bool tensor of shape [B*h, Q, HW].
        """
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)  # batch-first for the heads
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        # fcclip head
        maskpool_embeddings = self.mask_pooling(x=mask_features, mask=outputs_mask)  # [B, Q, C]
        maskpool_embeddings = self._mask_pooling_proj(maskpool_embeddings)
        class_embed = self.class_embed(maskpool_embeddings + decoder_output)
        outputs_class = get_classification_logits(class_embed, text_classifier, self.logit_scale, num_templates)

        # NOTE: prediction is of higher-resolution
        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
        # must use bool type
        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        return outputs_class, outputs_mask, attn_mask

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        """Pack every prediction step except the last into a list of dicts."""
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]