# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
from torch.nn import functional as F

from maskrcnn_benchmark.layers import Conv2d, _NewEmptyTensorOp
from maskrcnn_benchmark.layers import ConvTranspose2d

from ...utils import permute_and_flatten


class MaskRCNNC4Predictor(nn.Module):
    """Standard mask predictor: one upsampling deconv followed by per-class 1x1 logits."""

    def __init__(self, cfg):
        super(MaskRCNNC4Predictor, self).__init__()
        # TODO: a hack for binary mask head
        # num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        num_classes = 2
        dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]

        if cfg.MODEL.ROI_HEADS.USE_FPN:
            num_inputs = dim_reduced
        else:
            # C4 backbone: derive the res5 output width from the res2 width.
            stage_index = 4
            stage2_relative_factor = 2 ** (stage_index - 1)
            res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
            num_inputs = res2_out_channels * stage2_relative_factor

        self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
        self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0)

        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                # Caffe2 implementation uses MSRAFill, which in fact
                # corresponds to kaiming_normal_ in PyTorch
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        x = F.relu(self.conv5_mask(x))
        return self.mask_fcn_logits(x)


class VLMaskRCNNC4Predictor(nn.Module):
    """Vision-language mask predictor: instead of fixed per-class 1x1 logits,
    each upsampled mask pixel is scored against every language token embedding
    with a temperature-scaled dot product, yielding one logit map per token."""

    def __init__(self, cfg):
        super(VLMaskRCNNC4Predictor, self).__init__()
        dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]

        if cfg.MODEL.ROI_HEADS.USE_FPN:
            num_inputs = dim_reduced
        else:
            # C4 backbone: derive the res5 output width from the res2 width.
            stage_index = 4
            stage2_relative_factor = 2 ** (stage_index - 1)
            res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
            num_inputs = res2_out_channels * stage2_relative_factor

        self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
        # self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0)
        log_scale = cfg.MODEL.DYHEAD.LOG_SCALE
        # The "class" dimension of the output is the number of language tokens.
        self.out_dim = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
        self.dot_product_projection_image = nn.Identity()
        self.dot_product_projection_text = nn.Linear(
            cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, dim_reduced, bias=True
        )
        # Learned temperature for the dot product, and a learned per-token bias
        # computed from the language embedding.
        self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)
        self.bias_lang = nn.Parameter(
            torch.zeros(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM), requires_grad=True
        )

        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                # Caffe2 implementation uses MSRAFill, which in fact
                # corresponds to kaiming_normal_ in PyTorch
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")

    def forward(self, x, language_dict_features):
        x = F.relu(self.conv5_mask(x))
        if x.numel() <= 0:
            # Empty RoI batch: propagate an empty tensor with the right shape.
            output_shape = [x.shape[0], self.out_dim] + list(x.shape[-2:])
            return _NewEmptyTensorOp.apply(x, output_shape)

        embedding = language_dict_features["hidden"]
        # norm: L2-normalize each token embedding, (B, num_tokens, lang_dim)
        embedding = F.normalize(embedding, p=2, dim=-1)
        # project tokens into the mask feature space, (B, num_tokens, C)
        dot_product_proj_tokens = self.dot_product_projection_text(embedding / 2.0)
        dot_product_proj_tokens_bias = torch.matmul(embedding, self.bias_lang)

        B, C, H, W = x.shape
        # add bias (language)
        dot_product_proj_queries = self.dot_product_projection_image(x)
        # flatten spatial positions into queries, (B, H*W, C)
        dot_product_proj_queries = permute_and_flatten(dot_product_proj_queries, B, -1, C, H, W)

        A = dot_product_proj_queries.shape[1]
        bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(1, A, 1)

        # dot product: per-position, per-token similarity, (B, H*W, num_tokens)
        dot_product_logit = (
            torch.matmul(dot_product_proj_queries, dot_product_proj_tokens.transpose(-1, -2))
            / self.log_scale.exp()
        ) + bias
        # clamp for stability
        dot_product_logit = torch.clamp(dot_product_logit, max=50000)
        dot_product_logit = torch.clamp(dot_product_logit, min=-50000)
        # back to map layout: (B, num_tokens, H, W)
        dot_product_logit = dot_product_logit.view(B, H, W, self.out_dim).permute(0, 3, 1, 2)
        return dot_product_logit


_ROI_MASK_PREDICTOR = {
    "MaskRCNNC4Predictor": MaskRCNNC4Predictor,
    "VLMaskRCNNC4Predictor": VLMaskRCNNC4Predictor,
}


def make_roi_mask_predictor(cfg):
    func = _ROI_MASK_PREDICTOR[cfg.MODEL.ROI_MASK_HEAD.PREDICTOR]
    return func(cfg)
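

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an illustration added here, not part of the
# original module): it fakes just enough of the config tree with
# SimpleNamespace objects to instantiate the vision-language predictor and
# check output shapes. All numbers below (channel widths, token counts, RoI
# resolution) are illustrative placeholders, not the project's real defaults,
# and the relative import of permute_and_flatten above means this file must be
# run as a module from the repo root (python -m ...) rather than as a script.
if __name__ == "__main__":
    from types import SimpleNamespace

    _cfg = SimpleNamespace(
        MODEL=SimpleNamespace(
            ROI_MASK_HEAD=SimpleNamespace(
                CONV_LAYERS=[256], PREDICTOR="VLMaskRCNNC4Predictor"
            ),
            ROI_HEADS=SimpleNamespace(USE_FPN=True),
            RESNETS=SimpleNamespace(RES2_OUT_CHANNELS=256),
            DYHEAD=SimpleNamespace(LOG_SCALE=0.0),
            LANGUAGE_BACKBONE=SimpleNamespace(MAX_QUERY_LEN=256, LANG_DIM=768),
        )
    )
    predictor = make_roi_mask_predictor(_cfg)

    # Dummy inputs: 4 RoIs of 14x14 mask features, plus per-token language
    # embeddings whose leading dim must match the RoI batch dim here, and
    # whose token count must equal MAX_QUERY_LEN for the final view() to work.
    x = torch.randn(4, 256, 14, 14)
    language_dict_features = {"hidden": torch.randn(4, 256, 768)}

    logits = predictor(x, language_dict_features)
    # conv5_mask upsamples 14x14 -> 28x28; one logit map per language token.
    print(logits.shape)  # torch.Size([4, 256, 28, 28])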