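# NOTE(review): the following lines are web-page residue captured with this file, not code:
# zdou0830's picture
# history blame
# No virus
# 4.79 kB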
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
from torch.nn import functional as F
from maskrcnn_benchmark.layers import Conv2d, _NewEmptyTensorOp
from maskrcnn_benchmark.layers import ConvTranspose2d
from ...utils import permute_and_flatten
class MaskRCNNC4Predictor(nn.Module):
    """Mask predictor for the C4 ROI head.

    Applies a transposed conv (kernel 2, stride 2, padding 0) followed by ReLU,
    then a 1x1 conv producing ``num_classes`` mask logit channels. The number
    of classes is currently hard-coded to 2 (binary mask head, see TODO).
    """

    def __init__(self, cfg):
        super(MaskRCNNC4Predictor, self).__init__()
        # TODO: a hack for binary mask head
        # num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        num_classes = 2
        dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]

        # Input channel count: RES2_OUT_CHANNELS scaled by 2 ** (stage_index - 1)
        # with stage_index = 4. Presumably this is the width of the ResNet stage
        # feeding the C4 head — TODO confirm against the backbone config.
        stage_index = 4
        stage2_relative_factor = 2 ** (stage_index - 1)
        res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
        num_inputs = res2_out_channels * stage2_relative_factor

        # Positional args assumed to follow torch.nn.ConvTranspose2d /
        # torch.nn.Conv2d: (in, out, kernel_size, stride, padding).
        self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
        self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0)

        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                # Caffe2 implementation uses MSRAFill, which in fact
                # corresponds to kaiming_normal_ in PyTorch
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        """Return mask logits for ROI features ``x``.

        Args:
            x: ROI feature map with ``num_inputs`` channels; presumably
                (num_rois, num_inputs, H, W) — TODO confirm against caller.

        Returns:
            Logits with ``num_classes`` (= 2) channels, computed from the
            ReLU-activated output of ``conv5_mask``.
        """
        x = F.relu(self.conv5_mask(x))
        return self.mask_fcn_logits(x)
class VLMaskRCNNC4Predictor(nn.Module):
    """Vision-language mask predictor for the C4 ROI head.

    Instead of a fixed per-class 1x1 conv, every spatial location of the
    upsampled ROI feature map is scored against every language token with a
    temperature-scaled dot product plus a per-token bias. The output therefore
    has one logit channel per language token.
    """

    def __init__(self, cfg):
        super(VLMaskRCNNC4Predictor, self).__init__()
        dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]

        # Input channel count: RES2_OUT_CHANNELS scaled by 2 ** (stage_index - 1)
        # with stage_index = 4 (same derivation as MaskRCNNC4Predictor).
        stage_index = 4
        stage2_relative_factor = 2 ** (stage_index - 1)
        res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
        num_inputs = res2_out_channels * stage2_relative_factor

        # Positional args assumed to follow torch.nn.ConvTranspose2d:
        # (in, out, kernel_size, stride, padding).
        self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)

        log_scale = cfg.MODEL.DYHEAD.LOG_SCALE
        lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
        # Image features are used unprojected (they already have dim_reduced
        # channels after conv5_mask); text tokens are projected from lang_dim
        # into the same dim_reduced space so the two can be dot-multiplied.
        self.dot_product_projection_image = nn.Identity()
        self.dot_product_projection_text = nn.Linear(lang_dim, dim_reduced, bias=True)
        # Learnable temperature: dot products are divided by exp(log_scale).
        self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)
        # Per-token additive bias is computed as <token_embedding, bias_lang>.
        self.bias_lang = nn.Parameter(torch.zeros(lang_dim), requires_grad=True)

        for name, param in self.named_parameters():
            if "bias" in name:
                # Matches the conv/Linear biases and also "bias_lang".
                nn.init.constant_(param, 0)
            elif "weight" in name:
                # Caffe2 implementation uses MSRAFill, which in fact
                # corresponds to kaiming_normal_ in PyTorch
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")

    def forward(self, x, language_dict_features):
        """Score every upsampled ROI pixel against every language token.

        Args:
            x: ROI feature map with ``num_inputs`` channels; presumably
                (B, num_inputs, H_in, W_in) — TODO confirm against caller.
            language_dict_features: dict whose ``"hidden"`` entry holds token
                embeddings with last dim ``LANG_DIM``. The code treats it as
                (batch, num_tokens, LANG_DIM) — TODO confirm against caller.

        Returns:
            Logits of shape (B, num_tokens, H, W), where (B, H, W) come from the
            ``conv5_mask`` output. For empty input, an empty tensor of that
            shape built via ``_NewEmptyTensorOp``.
        """
        x = F.relu(self.conv5_mask(x))

        embedding = language_dict_features["hidden"]
        # One output channel per language token. This equals the last dim of
        # the logit tensor below (the Linear projection keeps dim -2 intact).
        num_tokens = embedding.shape[-2]

        if x.numel() <= 0:
            output_shape = [x.shape[0], num_tokens] + list(x.shape[-2:])
            return _NewEmptyTensorOp.apply(x, output_shape)

        # L2-normalize token embeddings along the feature dimension.
        embedding = F.normalize(embedding, p=2, dim=-1)
        # NOTE(review): the /2.0 scaling before projection is kept from the
        # original; its rationale is not visible in this file.
        dot_product_proj_tokens = self.dot_product_projection_text(embedding / 2.0)
        # Per-token bias, shape (..., num_tokens).
        dot_product_proj_tokens_bias = torch.matmul(embedding, self.bias_lang)

        B, C, H, W = x.shape
        dot_product_proj_queries = self.dot_product_projection_image(x)
        # Presumably flattens to (B, H*W, C); permute_and_flatten is defined
        # in ...utils — TODO confirm.
        dot_product_proj_queries = permute_and_flatten(dot_product_proj_queries, B, -1, C, H, W)

        # (B, H*W, C) @ (.., C, num_tokens) -> (B, H*W, num_tokens), scaled by
        # the learned temperature. unsqueeze(1) broadcasts the per-token bias
        # over all query positions (equivalent to repeating it H*W times).
        dot_product_logit = (
            torch.matmul(dot_product_proj_queries, dot_product_proj_tokens.transpose(-1, -2))
            / self.log_scale.exp()
        ) + dot_product_proj_tokens_bias.unsqueeze(1)

        # Clamp for numerical stability.
        dot_product_logit = torch.clamp(dot_product_logit, min=-50000, max=50000)
        return dot_product_logit.view(B, H, W, num_tokens).permute(0, 3, 1, 2)
# Registry of available ROI mask predictor classes, keyed by class name.
_ROI_MASK_PREDICTOR = dict(
    MaskRCNNC4Predictor=MaskRCNNC4Predictor,
    VLMaskRCNNC4Predictor=VLMaskRCNNC4Predictor,
)
def make_roi_mask_predictor(cfg):
    """Instantiate the ROI mask predictor named in the config.

    Args:
        cfg: config node. The predictor name is read from
            ``cfg.MODEL.ROI_MASK_HEAD.PREDICTOR``, and the whole ``cfg`` is
            passed to the selected predictor's constructor.

    Returns:
        An instance of the predictor class registered under that name in
        ``_ROI_MASK_PREDICTOR``.

    Raises:
        KeyError: if the configured name is not a registered predictor.
    """
    predictor_name = cfg.MODEL.ROI_MASK_HEAD.PREDICTOR
    try:
        func = _ROI_MASK_PREDICTOR[predictor_name]
    except KeyError as err:
        raise KeyError(
            f"Unknown ROI mask predictor {predictor_name!r}; "
            f"available: {sorted(_ROI_MASK_PREDICTOR)}"
        ) from err
    return func(cfg)