# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d

from .position_encoding import PositionEmbeddingSine
from .transformer import Transformer


class TransformerPredictor(nn.Module):
    @configurable
    def __init__(
        self,
        in_channels,
        mask_classification=True,
        *,
        num_classes: int,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dropout: float,
        dim_feedforward: int,
        enc_layers: int,
        dec_layers: int,
        pre_norm: bool,
        deep_supervision: bool,
        mask_dim: int,
        enforce_input_project: bool,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dropout: dropout in Transformer
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            deep_supervision: whether to add supervision to every decoder layer
            mask_dim: mask feature dimension
            enforce_input_project: add an input projection 1x1 conv even if input
                channels and hidden dim are identical
        """
        super().__init__()
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=deep_supervision,
        )
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model

        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
            weight_init.c2_xavier_fill(self.input_proj)
        else:
            self.input_proj = nn.Sequential()
        self.aux_loss = deep_supervision

        # output FFNs
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        ret = {}
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        # Transformer parameters:
        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
        ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM

        return ret

    def forward(self, x, mask_features):
        pos = self.pe_layer(x)

        src = x
        mask = None
        hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)

        if self.mask_classification:
            outputs_class = self.class_embed(hs)
            out = {"pred_logits": outputs_class[-1]}
        else:
            out = {}

        if self.aux_loss:
            # [l, bs, queries, embed]
            mask_embed = self.mask_embed(hs)
            outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks[-1]
            out["aux_outputs"] = self._set_aux_loss(
                outputs_class if self.mask_classification else None, outputs_seg_masks
            )
        else:
            # FIXME h_boxes takes the last one computed, keep this in mind
            # [bs, queries, embed]
            mask_embed = self.mask_embed(hs[-1])
            outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionaries with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
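

# Minimal usage sketch (not part of the original module): constructs the
# predictor directly with explicit keyword arguments (the @configurable
# decorator only takes the from_config path when a cfg is passed) and runs a
# forward pass on random tensors. The hyperparameter values and tensor shapes
# below are illustrative assumptions that mirror common MaskFormer defaults,
# not values mandated by this file.
if __name__ == "__main__":
    predictor = TransformerPredictor(
        in_channels=256,          # channels of the input feature map `x`
        mask_classification=True,
        num_classes=80,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dropout=0.1,
        dim_feedforward=2048,
        enc_layers=0,
        dec_layers=6,
        pre_norm=False,
        deep_supervision=True,
        mask_dim=256,             # must match the channel dim of `mask_features`
        enforce_input_project=False,
    )

    # `x` stands in for a (typically low-resolution) feature map fed to the
    # Transformer; `mask_features` stands in for the higher-resolution
    # per-pixel embeddings that the predicted mask embeddings are dotted with.
    x = torch.randn(2, 256, 32, 32)
    mask_features = torch.randn(2, 256, 128, 128)

    out = predictor(x, mask_features)
    print(out["pred_logits"].shape)  # [2, 100, 81] = [bs, queries, num_classes + 1]
    print(out["pred_masks"].shape)   # [2, 100, 128, 128] = [bs, queries, H, W]
    print(len(out["aux_outputs"]))   # 5 = dec_layers - 1 intermediate predictions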