Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from maskrcnn_benchmark.layers import Conv2d, _NewEmptyTensorOp | |
from maskrcnn_benchmark.layers import ConvTranspose2d | |
from ...utils import permute_and_flatten | |
class MaskRCNNC4Predictor(nn.Module):
    """Mask predictor made of a 2x transposed convolution and a 1x1 convolution.

    ``forward`` upsamples the incoming features with ``conv5_mask``
    (kernel 2, stride 2), applies ReLU, and maps the result to mask logits
    with ``mask_fcn_logits``. The number of output channels is fixed to 2
    (binary mask head) instead of being read from the config.
    """

    def __init__(self, cfg):
        """Create the two layers and initialise their parameters.

        Args:
            cfg: configuration object. Reads
                ``MODEL.ROI_MASK_HEAD.CONV_LAYERS`` (last entry is the hidden
                width), ``MODEL.ROI_HEADS.USE_FPN`` and, only when FPN is off,
                ``MODEL.RESNETS.RES2_OUT_CHANNELS``.
        """
        super().__init__()

        # TODO: a hack for binary mask head
        # (the class count would otherwise be cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES)
        out_channels = 2
        hidden_channels = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]

        in_channels = hidden_channels
        if not cfg.MODEL.ROI_HEADS.USE_FPN:
            # Without FPN the input width is the res2 width scaled by
            # 2 ** (stage_index - 1) with stage_index = 4.
            stage_index = 4
            in_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * 2 ** (stage_index - 1)

        self.conv5_mask = ConvTranspose2d(in_channels, hidden_channels, 2, 2, 0)
        self.mask_fcn_logits = Conv2d(hidden_channels, out_channels, 1, 1, 0)

        for param_name, tensor in self.named_parameters():
            if "bias" in param_name:
                nn.init.constant_(tensor, 0)
                continue
            if "weight" in param_name:
                # Caffe2 implementation uses MSRAFill, which in fact
                # corresponds to kaiming_normal_ in PyTorch
                nn.init.kaiming_normal_(tensor, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        """Return mask logits for the feature tensor ``x``.

        Applies ``conv5_mask``, then ReLU, then ``mask_fcn_logits``.
        """
        upsampled = self.conv5_mask(x)
        activated = F.relu(upsampled)
        return self.mask_fcn_logits(activated)
class VLMaskRCNNC4Predictor(nn.Module):
    """Mask predictor that scores each upsampled pixel feature against language tokens.

    Instead of a 1x1 classification convolution, the logits are dot products
    between the (identity-projected) visual features and linearly projected,
    L2-normalised token embeddings, divided by ``exp(log_scale)`` and offset
    by a learned per-token bias.
    """

    def __init__(self, cfg):
        """Create the upsampling layer, the text projection and the learned scalars.

        Args:
            cfg: configuration object. Reads
                ``MODEL.ROI_MASK_HEAD.CONV_LAYERS`` (last entry is the hidden
                width), ``MODEL.ROI_HEADS.USE_FPN``,
                ``MODEL.RESNETS.RES2_OUT_CHANNELS`` (only when FPN is off),
                ``MODEL.DYHEAD.LOG_SCALE``,
                ``MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN`` and
                ``MODEL.LANGUAGE_BACKBONE.LANG_DIM``.
        """
        super().__init__()

        hidden_channels = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1]

        in_channels = hidden_channels
        if not cfg.MODEL.ROI_HEADS.USE_FPN:
            # Without FPN the input width is the res2 width scaled by
            # 2 ** (stage_index - 1) with stage_index = 4.
            stage_index = 4
            in_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * 2 ** (stage_index - 1)

        self.conv5_mask = ConvTranspose2d(in_channels, hidden_channels, 2, 2, 0)
        # No mask_fcn_logits conv here: logits come from the image/text dot product.

        initial_log_scale = cfg.MODEL.DYHEAD.LOG_SCALE
        # Number of output channels produced by forward().
        self.out_dim = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN

        lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
        self.dot_product_projection_image = nn.Identity()
        self.dot_product_projection_text = nn.Linear(lang_dim, hidden_channels, bias=True)
        self.log_scale = nn.Parameter(torch.Tensor([initial_log_scale]), requires_grad=True)
        self.bias_lang = nn.Parameter(torch.zeros(lang_dim), requires_grad=True)

        # Name-based init: anything containing "bias" (this includes
        # ``bias_lang``) is zeroed, anything containing "weight" gets Kaiming
        # init; ``log_scale`` matches neither and keeps its configured value.
        for param_name, tensor in self.named_parameters():
            if "bias" in param_name:
                nn.init.constant_(tensor, 0)
                continue
            if "weight" in param_name:
                # Caffe2 implementation uses MSRAFill, which in fact
                # corresponds to kaiming_normal_ in PyTorch
                nn.init.kaiming_normal_(tensor, mode="fan_out", nonlinearity="relu")

    def forward(self, x, language_dict_features):
        """Compute per-pixel, per-token mask logits.

        Args:
            x: 4-D feature tensor; after ``conv5_mask`` + ReLU it is unpacked
                as ``(B, C, H, W)``.
            language_dict_features: mapping whose ``"hidden"`` entry holds the
                token embeddings (presumably ``(batch, tokens, LANG_DIM)`` --
                TODO confirm against the caller).

        Returns:
            Logits viewed as ``(B, H, W, out_dim)`` and permuted to
            ``(B, out_dim, H, W)``. When the upsampled tensor has no elements,
            an empty tensor of shape ``[B, out_dim, H, W]`` is returned via
            ``_NewEmptyTensorOp``.
        """
        x = F.relu(self.conv5_mask(x))

        if x.numel() <= 0:
            empty_shape = [x.shape[0], self.out_dim, *x.shape[-2:]]
            return _NewEmptyTensorOp.apply(x, empty_shape)

        # L2-normalise the token embeddings along the feature dimension.
        token_embedding = F.normalize(language_dict_features["hidden"], p=2, dim=-1)
        # The projection sees the halved embedding; the bias term uses it unscaled.
        projected_tokens = self.dot_product_projection_text(token_embedding / 2.0)
        token_bias = torch.matmul(token_embedding, self.bias_lang)

        B, C, H, W = x.shape

        # NOTE(review): permute_and_flatten is defined outside this file; it
        # presumably yields one C-dim query per spatial location, i.e.
        # (B, H * W, C) -- confirm.
        pixel_queries = permute_and_flatten(self.dot_product_projection_image(x), B, -1, C, H, W)
        num_queries = pixel_queries.shape[1]

        # Language bias, repeated once per query position.
        bias = token_bias.unsqueeze(1).repeat(1, num_queries, 1)

        # Scaled dot product between pixel queries and projected tokens.
        similarity = torch.matmul(pixel_queries, projected_tokens.transpose(-1, -2))
        logits = similarity / self.log_scale.exp() + bias

        # Clamp for stability.
        logits = torch.clamp(logits, min=-50000, max=50000)

        # The view requires the logits to hold exactly B * H * W * out_dim values.
        return logits.view(B, H, W, self.out_dim).permute(0, 3, 1, 2)
# Registry of available mask predictors, keyed by the name used in
# cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.
_ROI_MASK_PREDICTOR = {
    "MaskRCNNC4Predictor": MaskRCNNC4Predictor,
    "VLMaskRCNNC4Predictor": VLMaskRCNNC4Predictor,
}
def make_roi_mask_predictor(cfg):
    """Build the mask predictor selected by ``cfg.MODEL.ROI_MASK_HEAD.PREDICTOR``.

    The name is looked up in ``_ROI_MASK_PREDICTOR`` (an unregistered name
    raises ``KeyError``) and the matching class is instantiated with ``cfg``.
    """
    predictor_name = cfg.MODEL.ROI_MASK_HEAD.PREDICTOR
    predictor_cls = _ROI_MASK_PREDICTOR[predictor_name]
    return predictor_cls(cfg)