# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union

import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from ..transformer.position_encoding import PositionEmbeddingSine
from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer

logger = logging.getLogger(__name__)


def build_pixel_decoder(cfg, input_shape):
    """
    Build a pixel decoder from `cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME`.

    The name is looked up in `SEM_SEG_HEADS_REGISTRY` and the registered class is
    instantiated with `(cfg, input_shape)`.

    Raises:
        ValueError: if the built module has no callable `forward_features` attribute.
    """
    name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
    model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
    forward_features = getattr(model, "forward_features", None)
    if not callable(forward_features):
        raise ValueError(
            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
            f"Please implement forward_features for {name} to only return mask features."
        )
    return model


@SEM_SEG_HEADS_REGISTRY.register()
class BasePixelDecoder(nn.Module):
    """
    FPN-style top-down pixel decoder.

    Every input feature except the one with the largest stride gets a 1x1 lateral conv
    and a 3x3 output conv; the largest-stride feature only gets a 3x3 output conv.
    A final 3x3 conv (`mask_features`) maps the finest-level result to `mask_dim` channels.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes (channels and stride) of the input features
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers

        Raises:
            ValueError: if `input_shape` is empty.
        """
        super().__init__()
        if not input_shape:
            # Without any input feature the top-down loop in forward_features never
            # runs and would fail on an unbound variable; fail early and clearly.
            raise ValueError("input_shape must contain at least one feature.")

        ordered_shapes = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, _ in ordered_shapes]  # starting from "res2" to "res5"
        feature_channels = [v.channels for _, v in ordered_shapes]

        lateral_convs = []
        output_convs = []

        # Bias is enabled only when `norm` is the empty string.
        # NOTE(review): the default `norm=None` therefore also yields bias=False --
        # confirm that this is intended.
        use_bias = norm == ""
        last_idx = len(feature_channels) - 1
        for idx, in_channels in enumerate(feature_channels):
            if idx == last_idx:
                # Largest-stride feature: no lateral conv, only an output conv.
                output_norm = get_norm(norm, conv_dim)
                output_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(None)
                output_convs.append(output_conv)
            else:
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)

                lateral_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=1,
                    bias=use_bias,
                    norm=lateral_norm,
                )
                output_conv = Conv2d(
                    conv_dim,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(lateral_conv)
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(lateral_conv)
                output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_dim = mask_dim
        self.mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        weight_init.c2_xavier_fill(self.mask_features)

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Translate `cfg` into constructor keyword arguments.

        Only the entries of `input_shape` whose name is listed in
        `cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES` are kept.
        """
        ret = {}
        ret["input_shape"] = {
            k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
        }
        ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
        ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
        return ret

    def forward_features(self, features):
        """
        Run the top-down pathway and produce the mask features.

        Args:
            features: mapping indexed by the names in `self.in_features`.

        Returns:
            tuple: `(self.mask_features(y), None)`, where `y` is the output of the
            last (smallest-stride) level.
        """
        # Reverse feature maps into top-down order (from low to high resolution)
        for name, lateral_conv, output_conv in zip(
            reversed(self.in_features), self.lateral_convs, self.output_convs
        ):
            x = features[name]
            if lateral_conv is None:
                y = output_conv(x)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), None

    def forward(self, features, targets=None):
        """
        Log a warning and delegate to `self.forward_features(features)`.

        `targets` is accepted but not used. Subclasses that override
        `forward_features` inherit this method unchanged.
        """
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)


class TransformerEncoderOnly(nn.Module):
    """Wrapper that applies a `TransformerEncoder` to a 4D (NxCxHxW) feature map."""

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        # A final LayerNorm is only added in the pre-norm configuration.
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialize every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        """
        Encode `src` and return a tensor with the same NxCxHxW shape.

        Args:
            src: 4D input, unpacked as (bs, c, h, w).
            mask: optional key padding mask; flattened from dim 1 onward when given.
            pos_embed: positional embedding, flattened and permuted like `src`.
        """
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # Back from HWxNxC to NxCxHxW.
        return memory.permute(1, 2, 0).view(bs, c, h, w)


@SEM_SEG_HEADS_REGISTRY.register()
class TransformerEncoderPixelDecoder(BasePixelDecoder):
    """
    `BasePixelDecoder` variant that runs a transformer encoder on the largest-stride
    feature before the top-down pathway, and also returns the encoder output.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        transformer_dropout: float,
        transformer_nheads: int,
        transformer_dim_feedforward: int,
        transformer_enc_layers: int,
        transformer_pre_norm: bool,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            transformer_pre_norm: whether to use pre-layernorm or not
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers
        """
        # The base constructor sorts `input_shape` by stride and sets `self.in_features`.
        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)

        # Channels of the largest-stride feature (last after a stable sort by stride).
        in_channels = sorted(input_shape.values(), key=lambda v: v.stride)[-1].channels
        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
        weight_init.c2_xavier_fill(self.input_proj)
        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )
        n_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(n_steps, normalize=True)

        # update layer: the coarsest output conv now consumes the transformer output
        # (conv_dim channels) instead of the raw backbone feature, so replace the
        # module the base class registered under the same name.
        use_bias = norm == ""
        output_norm = get_norm(norm, conv_dim)
        output_conv = Conv2d(
            conv_dim,
            conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            norm=output_norm,
            activation=F.relu,
        )
        weight_init.c2_xavier_fill(output_conv)
        coarsest_name = "layer_{}".format(len(self.in_features))
        delattr(self, coarsest_name)
        self.add_module(coarsest_name, output_conv)
        self.output_convs[0] = output_conv

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Extend the base constructor arguments with the transformer settings from `cfg`."""
        ret = super().from_config(cfg, input_shape)
        ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
        ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
        # a separate config
        ret["transformer_enc_layers"] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS
        ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        return ret

    def forward_features(self, features):
        """
        Run the transformer encoder on the coarsest feature, then the top-down pathway.

        Args:
            features: mapping indexed by the names in `self.in_features`.

        Returns:
            tuple: `(self.mask_features(y), transformer_encoder_features)`, where the
            second element is the transformer encoder output for the coarsest feature.
        """
        # Reverse feature maps into top-down order (from low to high resolution)
        for name, lateral_conv, output_conv in zip(
            reversed(self.in_features), self.lateral_convs, self.output_convs
        ):
            x = features[name]
            if lateral_conv is None:
                transformer = self.input_proj(x)
                pos = self.pe_layer(x)
                transformer = self.transformer(transformer, None, pos)
                y = output_conv(transformer)
                # save intermediate feature as input to Transformer decoder
                transformer_encoder_features = transformer
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), transformer_encoder_features