import math
import torch
import torch.nn.functional as F
from torch import nn
from collections import defaultdict, OrderedDict
from .inference import make_atss_postprocessor
from .loss import make_atss_loss_evaluator
from .anchor_generator import make_anchor_generator_complex
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.layers import Scale, DYReLU, SELayer, ModulatedDeformConv
from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d
from maskrcnn_benchmark.modeling.backbone.fbnet import *
from maskrcnn_benchmark.engine.inference import create_positive_map_label_to_token_from_positive_map
from ..utils import cat, concat_box_prediction_layers, permute_and_flatten
from maskrcnn_benchmark.utils.fuse_helper import (
FeatureResizer,
func_attention,
_make_mlp,
_make_conv,
_make_coord,
BiAttentionBlock,
AttentionT2I,
BiAttentionBlockForCheckpoint,
BertLMPredictionHead,
)
from transformers.models.bert.modeling_bert import (
BertConfig,
BertAttention,
BertIntermediate,
BertOutput,
BertPreTrainedModel,
)
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.modeling_utils import apply_chunking_to_forward
import torch.utils.checkpoint as checkpoint
from maskrcnn_benchmark.modeling.language_backbone.clip_model import QuickGELU, LayerNorm
from timm.models.layers import DropPath, trunc_normal_
class h_sigmoid(nn.Module):
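    """Hard sigmoid: ReLU6(x + 3) * h_max / 6, a cheap piecewise-linear gate used by the dynamic fusion below."""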
def __init__(self, inplace=True, h_max=1):
super(h_sigmoid, self).__init__()
self.relu = nn.ReLU6(inplace=inplace)
self.h_max = h_max
def forward(self, x):
return self.relu(x + 3) * self.h_max / 6
class BoxCoder(object):
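    """Encode ground-truth boxes as (dx, dy, dw, dh) regression targets relative to anchors
    and decode predictions back to boxes, using the standard R-CNN delta parameterization
    with fixed weights (10, 10, 5, 5). decode() clamps dw/dh before exponentiation to keep
    torch.exp() from overflowing."""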
def __init__(self, cfg):
self.cfg = cfg
def encode(self, gt_boxes, anchors):
TO_REMOVE = 1 # TODO remove
ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2
gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE
gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE
gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2
gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2
wx, wy, ww, wh = (10.0, 10.0, 5.0, 5.0)
if gt_ctr_x.nelement() == 0:
targets_dx = torch.zeros_like(ex_ctr_x)
targets_dy = torch.zeros_like(ex_ctr_y)
targets_dw = torch.zeros_like(ex_widths)
targets_dh = torch.zeros_like(ex_heights)
else:
targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
targets_dw = ww * torch.log(gt_widths / ex_widths)
targets_dh = wh * torch.log(gt_heights / ex_heights)
targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
return targets
def decode(self, preds, anchors):
anchors = anchors.to(preds.dtype)
TO_REMOVE = 1 # TODO remove
widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2
wx, wy, ww, wh = (10.0, 10.0, 5.0, 5.0)
dx = preds[:, 0::4] / wx
dy = preds[:, 1::4] / wy
dw = preds[:, 2::4] / ww
dh = preds[:, 3::4] / wh
# Prevent sending too large values into torch.exp()
dw = torch.clamp(dw, max=math.log(1000.0 / 16))
dh = torch.clamp(dh, max=math.log(1000.0 / 16))
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]
pred_boxes = torch.zeros_like(preds)
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1)
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1)
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1)
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1)
return pred_boxes
class Conv3x3Norm(torch.nn.Module):
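    """3x3 convolution (a ModulatedDeformConv when `deformable=True`) followed by a
    configurable normalization layer: "bn", "sbn", "nsbn", ("gn", num_groups) or "af"
    (frozen batch norm)."""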
def __init__(self, in_channels, out_channels, stride, groups=1, deformable=False, bn_type=None):
super(Conv3x3Norm, self).__init__()
if deformable:
self.conv = ModulatedDeformConv(
in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups
)
else:
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups)
if isinstance(bn_type, (list, tuple)):
assert len(bn_type) == 2
assert bn_type[0] == "gn"
gn_group = bn_type[1]
bn_type = bn_type[0]
if bn_type == "bn":
bn_op = nn.BatchNorm2d(out_channels)
elif bn_type == "sbn":
bn_op = nn.SyncBatchNorm(out_channels)
elif bn_type == "nsbn":
bn_op = NaiveSyncBatchNorm2d(out_channels)
elif bn_type == "gn":
bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=out_channels)
elif bn_type == "af":
bn_op = FrozenBatchNorm2d(out_channels)
if bn_type is not None:
self.bn = bn_op
else:
self.bn = None
def forward(self, input, **kwargs):
x = self.conv(input, **kwargs)
if self.bn:
x = self.bn(x)
return x
class DyConv(torch.nn.Module):
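    """Dynamic convolution over the feature pyramid. Each level is combined with its
    neighbours through three convolutions (same level, stride-2 from the finer level,
    upsampled from the coarser level), optionally re-weighted by a learned scale-wise
    attention (dyfuse), activated with DyReLU, and driven by predicted deformable
    offsets when enabled."""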
def __init__(
self,
in_channels=256,
out_channels=256,
conv_func=nn.Conv2d,
use_dyfuse=True,
use_dyrelu=False,
use_deform=False,
):
super(DyConv, self).__init__()
self.DyConv = nn.ModuleList()
self.DyConv.append(conv_func(in_channels, out_channels, 1))
self.DyConv.append(conv_func(in_channels, out_channels, 1))
self.DyConv.append(conv_func(in_channels, out_channels, 2))
if use_dyfuse:
self.AttnConv = nn.Sequential(
nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels, 1, kernel_size=1), nn.ReLU(inplace=True)
)
self.h_sigmoid = h_sigmoid()
else:
self.AttnConv = None
if use_dyrelu:
self.relu = DYReLU(in_channels, out_channels)
else:
self.relu = nn.ReLU()
if use_deform:
self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
else:
self.offset = None
self.init_weights()
def init_weights(self):
for m in self.DyConv.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, 0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
if self.AttnConv is not None:
for m in self.AttnConv.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, 0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
def forward(self, inputs):
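        """Aggregate, for every pyramid level, the convolved current level, the stride-2
        convolution of the finer neighbour (level - 1) and the bilinearly upsampled
        convolution of the coarser neighbour (level + 1); average them (attention-weighted
        when AttnConv is enabled) and apply the activation. Language features pass through
        unchanged."""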
visual_feats = inputs["visual"]
language_dict_features = inputs["lang"]
next_x = []
for level, feature in enumerate(visual_feats):
conv_args = dict()
if self.offset is not None:
offset_mask = self.offset(feature)
offset = offset_mask[:, :18, :, :]
mask = offset_mask[:, 18:, :, :].sigmoid()
conv_args = dict(offset=offset, mask=mask)
temp_fea = [self.DyConv[1](feature, **conv_args)]
if level > 0:
temp_fea.append(self.DyConv[2](visual_feats[level - 1], **conv_args))
if level < len(visual_feats) - 1:
temp_fea.append(
                    F.interpolate(
                        self.DyConv[0](visual_feats[level + 1], **conv_args),
                        size=[feature.size(2), feature.size(3)],
                        mode="bilinear",
                        align_corners=True,  # equivalent to the deprecated F.upsample_bilinear
                    )
)
mean_fea = torch.mean(torch.stack(temp_fea), dim=0, keepdim=False)
if self.AttnConv is not None:
attn_fea = []
res_fea = []
for fea in temp_fea:
res_fea.append(fea)
attn_fea.append(self.AttnConv(fea))
res_fea = torch.stack(res_fea)
spa_pyr_attn = self.h_sigmoid(torch.stack(attn_fea))
mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False)
next_x.append(mean_fea)
next_x = [self.relu(item) for item in next_x]
features_dict = {"visual": next_x, "lang": language_dict_features}
return features_dict
class BertEncoderLayer(BertPreTrainedModel):
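    """A single BERT layer (self-attention + feed-forward) applied to the language branch
    of the fused feature dict; the visual branch is passed through unchanged."""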
def __init__(self, config, clamp_min_for_underflow=False, clamp_max_for_overflow=False):
super().__init__(config)
self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
from maskrcnn_benchmark.modeling.rpn.modeling_bert import BertAttention, BertIntermediate, BertOutput
self.attention = BertAttention(config, clamp_min_for_underflow, clamp_max_for_overflow)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, inputs):
language_dict_features = inputs["lang"]
hidden_states = language_dict_features["hidden"]
attention_mask = language_dict_features["masks"]
device = hidden_states.device
input_shape = hidden_states.size()[:-1]
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
self_attention_outputs = self.attention(
hidden_states,
extended_attention_mask,
None,
output_attentions=False,
past_key_value=None,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
hidden_states = outputs[0]
language_dict_features["hidden"] = hidden_states
features_dict = {"visual": inputs["visual"], "lang": language_dict_features}
return features_dict
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class CLIPTransformerLayer(nn.Module):
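    """A CLIP-style pre-LayerNorm transformer block (multi-head self-attention + QuickGELU
    MLP with optional DropPath) applied to the language hidden states, using the text mask
    as a key-padding mask."""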
def __init__(self, config):
super().__init__()
self.config = config
d_model = self.config.MODEL.CLIP.WIDTH
n_head = self.config.MODEL.CLIP.HEADS
drop_path = self.config.MODEL.CLIP.DROP_PATH
self.context_length = self.config.MODEL.CLIP.CONTEXT_LENGTH
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)
self.attn_mask = None
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d)):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0]
def forward(self, inputs):
language_dict_features = inputs["lang"]
x = language_dict_features["hidden"]
mask = language_dict_features["masks"]
        # build the key padding mask expected by nn.MultiheadAttention (True marks padded positions)
key_padding_mask = (1.0 - mask).to(torch.bool)
x = x.permute(1, 0, 2)
x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
x = x + self.drop_path(self.mlp(self.ln_2(x)))
x = x.permute(1, 0, 2)
language_dict_features["hidden"] = x
features_dict = {"visual": inputs["visual"], "lang": language_dict_features}
return features_dict
class DummyLayer(nn.Module):
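    """Identity placeholder used in the tower when no language layer is inserted."""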
def __init__(self):
super().__init__()
def forward(self, inputs):
return inputs
class VLFuse(torch.nn.Module):
"""
Early Fusion Module
"""
def __init__(self, cfg):
super(VLFuse, self).__init__()
self.init_configs(cfg)
self.cfg = cfg
self.use_checkpoint = False
if hasattr(cfg.MODEL.DYHEAD, "USE_CHECKPOINT"):
self.use_checkpoint = cfg.MODEL.DYHEAD.USE_CHECKPOINT
self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)
# early fusion module
print("EARLY FUSION ON, USING {}".format(cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE))
if cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-S":
# single-direction (text->image)
# text -> image
self.t2i_attn = AttentionT2I(
q_dim=self.joint_embedding_size,
k_dim=self.lang_dim,
embed_dim=self.embed_dim,
num_heads=self.n_head,
hidden_dim=self.t2i_hidden_dim,
dropout=0.1,
drop_path=0.0,
init_values=1.0 / cfg.MODEL.DYHEAD.NUM_CONVS,
mode="t2i",
use_layer_scale=cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_LAYER_SCALE,
clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW,
clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW,
)
elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-B":
# bi-direction (text->image, image->text)
self.b_attn = BiAttentionBlockForCheckpoint(
v_dim=self.joint_embedding_size,
l_dim=self.lang_dim,
embed_dim=self.embed_dim,
num_heads=self.n_head,
hidden_dim=self.i2t_hidden_dim,
dropout=0.1,
drop_path=0.0,
init_values=1.0 / cfg.MODEL.DYHEAD.NUM_CONVS,
cfg=cfg,
)
if (
self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL
and self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT
):
self.shrink_lang = FeatureResizer(self.lang_dim * 5, self.lang_dim, 0.1)
elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "SCAN":
# single-direction (text->image)
self.mapping_lang = _make_mlp(self.lang_dim, self.joint_embedding_size, self.joint_embedding_dropout)
self.joint_fusion = nn.ModuleList([_make_conv(self.joint_inp_dim, self.joint_out_dim, 1) for _ in range(5)])
elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "FILM":
# single-direction (text->image)
self.mapping_lang = _make_mlp(self.lang_dim, self.joint_embedding_size, self.joint_embedding_dropout)
self.gamma = nn.ModuleList(nn.Linear(self.joint_embedding_size, self.joint_inp_dim) for _ in range(5))
self.beta = nn.ModuleList(nn.Linear(self.joint_embedding_size, self.joint_inp_dim) for _ in range(5))
self.joint_fusion = nn.ModuleList([_make_conv(self.joint_inp_dim, self.joint_out_dim, 1) for _ in range(5)])
else:
print("NO FUSION INVOLVED.")
def init_configs(self, cfg):
# common params
self.lang_model = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
self.joint_embedding_size = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE
self.joint_embedding_dropout = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT
self.joint_mlp_layers = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_MLP_LAYERS
self.max_query_len = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
self.n_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS
self.coord_dim = 8
self.joint_inp_dim = self.coord_dim + self.joint_embedding_size
self.joint_out_dim = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_OUT_SIZE
# mha params
self.n_head = 8
self.embed_dim = 2048
self.t2i_hidden_dim = 1024 # 256 * 4
self.i2t_hidden_dim = 3072 # 768 * 4
if self.lang_model in ["bert-base-uncased", "roberta-base", "clip", "roberta-fused", "roberta-fused-v2", "roberta-fused-tiny"]:
self.lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
else:
self.lang_dim = 1024
def forward(self, x):
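        """Fuse the two modalities. `x` is a dict with keys "visual" (list of FPN feature
        maps) and "lang" (language feature dict); a dict of the same structure is returned
        with the fused features substituted."""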
visual_features = x["visual"]
language_dict_features = x["lang"]
batch_size = visual_features[0].shape[0]
device = visual_features[0].device
fused_visual_features = None
fused_language_dict_features = None
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-S":
language_feature = language_dict_features["hidden"]
mask = language_dict_features["masks"]
# text -> image
if self.use_checkpoint:
q0, q1, q2, q3, q4 = checkpoint.checkpoint(
self.t2i_attn,
visual_features[0],
visual_features[1],
visual_features[2],
visual_features[3],
visual_features[4],
language_feature,
language_feature,
mask,
self.dummy_tensor,
)
else:
q0, q1, q2, q3, q4 = self.t2i_attn(
visual_features[0],
visual_features[1],
visual_features[2],
visual_features[3],
visual_features[4],
language_feature,
language_feature,
attention_mask=mask,
)
fused_visual_features = [q0, q1, q2, q3, q4]
fused_language_dict_features = language_dict_features
elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-B":
if self.use_checkpoint:
q0, q1, q2, q3, q4, l0, l1, l2, l3, l4 = checkpoint.checkpoint(
self.b_attn,
visual_features[0],
visual_features[1],
visual_features[2],
visual_features[3],
visual_features[4],
language_dict_features["hidden"],
language_dict_features["masks"],
self.dummy_tensor,
)
else:
q0, q1, q2, q3, q4, l0, l1, l2, l3, l4 = self.b_attn(
visual_features[0],
visual_features[1],
visual_features[2],
visual_features[3],
visual_features[4],
language_dict_features["hidden"],
language_dict_features["masks"],
self.dummy_tensor,
)
fused_visual_features = [q0, q1, q2, q3, q4]
if (
self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL
and self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT
):
language_features = self.shrink_lang(torch.cat([l0, l1, l2, l3, l4], dim=-1))
else:
language_features = l0
language_dict_features["hidden"] = language_features
fused_language_dict_features = language_dict_features
elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "SCAN":
# text -> image
language_feature = language_dict_features["aggregate"]
language_feature = self.mapping_lang(language_feature)
visu_feat = []
for ii, feat in enumerate(visual_features):
attn_feat = func_attention(feat, language_feature, smooth=1, raw_feature_norm="softmax")
visu_feat.append(attn_feat)
fused_visual_features = [fusion(feat) for feat, fusion in zip(visu_feat, self.joint_fusion)]
fused_language_dict_features = language_dict_features
elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "FILM":
# text -> image
# relative position embedding
coord_feats = [_make_coord(batch_size, x.shape[2], x.shape[3]) for x in visual_features]
            # Only a global (sentence-level) language representation is used here;
            # more complex word-level modeling is also possible,
            # e.g. lang_feat = lang_feat['words'] with shape [seq_len, dim].
language_feature = language_dict_features["aggregate"]
language_feature = self.mapping_lang(language_feature)
# attention mechanism for fusion
            gamma = [torch.tanh(gamma(language_feature)) for gamma in self.gamma]
            beta = [torch.tanh(beta(language_feature)) for beta in self.beta]
visu_feat = []
for ii, feat in enumerate(visual_features):
coord_feat = coord_feats[ii].to(device)
feat = torch.cat([feat, coord_feat], dim=1)
b = beta[ii].view(batch_size, -1, 1, 1).expand_as(feat)
g = gamma[ii].view(batch_size, -1, 1, 1).expand_as(feat)
feat = F.relu(g * feat + b)
visu_feat.append(feat)
fused_visual_features = [fusion(feat) for feat, fusion in zip(visu_feat, self.joint_fusion)]
fused_language_dict_features = language_dict_features
else:
fused_visual_features = visual_features
fused_language_dict_features = language_dict_features
features_dict = {"visual": fused_visual_features, "lang": fused_language_dict_features}
return features_dict
class VLDyHead(torch.nn.Module):
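    """Vision-language dynamic head. Stacks NUM_CONVS blocks of optional VLFuse early
    fusion, a language self-attention layer (BERT or CLIP), and a DyConv over the feature
    pyramid, followed by classification, box-regression and centerness towers plus the
    optional token / contrastive / dot-product / MLM heads used by the grounding losses."""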
def __init__(self, cfg):
super(VLDyHead, self).__init__()
self.cfg = cfg
# bert_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE)
if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE in ["bert-base-uncased", "roberta-base"]:
lang_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE)
elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
lang_cfg = cfg
elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE in ["roberta-fused", "roberta-fused-v2", "roberta-fused-tiny"]:
lang_cfg = RobertaConfig.from_pretrained("roberta-base")
else:
lang_cfg = None
raise NotImplementedError
num_classes = cfg.MODEL.DYHEAD.NUM_CLASSES - 1
num_tokens = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
channels = cfg.MODEL.DYHEAD.CHANNELS
if cfg.MODEL.DYHEAD.USE_GN:
bn_type = ["gn", cfg.MODEL.GROUP_NORM.NUM_GROUPS]
elif cfg.MODEL.DYHEAD.USE_NSYNCBN:
bn_type = "nsbn"
elif cfg.MODEL.DYHEAD.USE_SYNCBN:
bn_type = "sbn"
else:
bn_type = None
use_dyrelu = cfg.MODEL.DYHEAD.USE_DYRELU
use_dyfuse = cfg.MODEL.DYHEAD.USE_DYFUSE
use_deform = cfg.MODEL.DYHEAD.USE_DFCONV
if cfg.MODEL.DYHEAD.CONV_FUNC:
conv_func = lambda i, o, s: eval(cfg.MODEL.DYHEAD.CONV_FUNC)(i, o, s, bn_type=bn_type)
else:
conv_func = lambda i, o, s: Conv3x3Norm(i, o, s, deformable=use_deform, bn_type=bn_type)
dyhead_tower = []
for i in range(cfg.MODEL.DYHEAD.NUM_CONVS):
if cfg.MODEL.DYHEAD.FUSE_CONFIG.EARLY_FUSE_ON:
# cross-modality fusion
dyhead_tower.append(VLFuse(cfg))
# self language path
if i < cfg.MODEL.DYHEAD.NUM_CONVS - 1 or cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT:
# dyhead_tower.append(
# BertEncoderLayer(
# bert_cfg,
# clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW,
# clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW)
# )
if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE in [
"bert-base-uncased",
"roberta-fused",
"roberta-fused-v2",
"roberta-fused-tiny",
"roberta-base",
]:
dyhead_tower.append(
BertEncoderLayer(
lang_cfg,
clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW,
clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW,
)
)
elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
dyhead_tower.append(CLIPTransformerLayer(lang_cfg))
else:
raise NotImplementedError
else:
dyhead_tower.append(DummyLayer())
# self vision path
dyhead_tower.append(
DyConv(
in_channels if i == 0 else channels,
channels,
conv_func=conv_func,
use_dyrelu=(use_dyrelu and in_channels == channels) if i == 0 else use_dyrelu,
use_dyfuse=(use_dyfuse and in_channels == channels) if i == 0 else use_dyfuse,
use_deform=(use_deform and in_channels == channels) if i == 0 else use_deform,
)
)
self.add_module("dyhead_tower", nn.Sequential(*dyhead_tower))
self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=1)
self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=1)
self.centerness = nn.Conv2d(channels, num_anchors * 1, kernel_size=1)
# initialize the bias for focal loss
prior_prob = cfg.MODEL.DYHEAD.PRIOR_PROB
bias_value = -math.log((1 - prior_prob) / prior_prob)
log_scale = self.cfg.MODEL.DYHEAD.LOG_SCALE
# soft token head
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
self.token_logits = nn.Conv2d(channels, num_anchors * num_tokens, kernel_size=1)
# ABLATION
# self.token_logits = nn.Conv2d(channels, num_anchors * num_tokens, kernel_size=1, bias=False)
# self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)
# self.bias0 = nn.Parameter(torch.Tensor([bias_value]), requires_grad=True)
# contrastive alignment head
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
            assert not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS
contrastive_hdim = cfg.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_HIDDEN_DIM
self.contrastive_align_projection_image = nn.Conv2d(channels, num_anchors * contrastive_hdim, kernel_size=1)
self.contrastive_align_projection_text = nn.Linear(channels, contrastive_hdim, bias=True)
self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)
# dot product soft token head
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            assert not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS
self.dot_product_projection_image = nn.Identity()
self.dot_product_projection_text = nn.Linear(
self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, num_anchors * channels, bias=True
)
self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)
# DEBUG
# self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)
self.bias_lang = nn.Parameter(torch.zeros(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM), requires_grad=True)
self.bias0 = nn.Parameter(torch.Tensor([bias_value]), requires_grad=True)
# initialization
for modules in [self.cls_logits, self.bbox_pred, self.centerness]:
for l in modules.modules():
if isinstance(l, nn.Conv2d):
torch.nn.init.normal_(l.weight, std=0.01)
torch.nn.init.constant_(l.bias, 0)
self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])
torch.nn.init.constant_(self.cls_logits.bias, bias_value)
# if use soft token loss
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
for modules in [self.token_logits]:
for l in modules.modules():
if isinstance(l, nn.Conv2d):
torch.nn.init.normal_(l.weight, std=0.01)
torch.nn.init.constant_(l.bias, 0)
torch.nn.init.constant_(self.token_logits.bias, bias_value)
# print(torch.norm(self.token_logits.weight))
# if use contrastive loss
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
for modules in [self.contrastive_align_projection_image]:
for l in modules.modules():
if isinstance(l, nn.Conv2d):
torch.nn.init.normal_(l.weight, std=0.01)
torch.nn.init.constant_(l.bias, 0)
# if use dot product token loss
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
for modules in [self.dot_product_projection_image]:
for l in modules.modules():
if isinstance(l, nn.Conv2d):
torch.nn.init.normal_(l.weight, std=0.01)
torch.nn.init.constant_(l.bias, bias_value)
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
lang_cfg = BertConfig.from_pretrained("bert-base-uncased")
lang_cfg.hidden_size = cfg.MODEL.CLIP.WIDTH
lang_cfg.vocab_size = cfg.MODEL.CLIP.VOCAB_SIZE
self.mlm_head = BertLMPredictionHead(lang_cfg) # nn.Linear(hidden_size, config.vocab_size, bias=False)
def forward(self, x, language_dict_features=None, embedding=None, swint_feature_c4=None):
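        """Run the fused tower over the FPN features and return per-level predictions as a
        10-tuple: (cls logits, box regression, centerness, token logits, projected text
        tokens, contrastive logits, dot-product logits, MLM logits, shallow image
        embeddings, fused visual features). Entries whose losses are disabled in the
        config are returned as None."""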
logits = []
bbox_reg = []
centerness = []
feat_inputs = {"visual": x, "lang": language_dict_features}
dyhead_tower = self.dyhead_tower(feat_inputs)
# soft token
t_logits = None
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
t_logits = []
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT:
embedding = dyhead_tower["lang"]["hidden"]
# MLM loss
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
mlm_logits = self.mlm_head(embedding)
else:
mlm_logits = None
# contrastive
contrastive_logits = None
proj_tokens = None
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
contrastive_logits = []
# follow MDETR's way
proj_tokens = F.normalize(self.contrastive_align_projection_text(embedding), p=2, dim=-1)
# dot product soft token
dot_product_logits = None
dot_product_proj_tokens = None
dot_product_proj_tokens_bias = None
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
dot_product_logits = []
# norm
embedding = F.normalize(embedding, p=2, dim=-1)
dot_product_proj_tokens = self.dot_product_projection_text(embedding / 2.0)
# w/o norm
# dot_product_proj_tokens = self.dot_product_projection_text(embedding / 28.0)
dot_product_proj_tokens_bias = torch.matmul(embedding, self.bias_lang) + self.bias0
# shallow contrastive (original feature from image & text encoder)
shallow_img_emb_feats = None
shallow_text_emb = None
if (
self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS
or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS
):
shallow_img_emb_feats = []
shallow_text_emb = embedding
# print([v.shape for v in x])
# shallow contrastive: use the feature from swint backbone
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
for b, feature in enumerate(swint_feature_c4):
# BF, CF, HF, WF = feat.shape
# shallow_img_emb = permute_and_flatten(feat, BF, -1, CF, HF, WF)
shallow_img_emb_feats.append(feature)
fused_visual_features = None
if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
fused_visual_features = []
# use the feature from FPN
for l, feature in enumerate(x):
logits.append(self.cls_logits(dyhead_tower["visual"][l]))
bbox_pred = self.scales[l](self.bbox_pred(dyhead_tower["visual"][l]))
bbox_reg.append(bbox_pred)
centerness.append(self.centerness(dyhead_tower["visual"][l]))
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
t_logits.append(self.token_logits(dyhead_tower["visual"][l]))
# ABLATION
# b = self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
# x = dyhead_tower["visual"][l]
# B, C, H, W = x.shape
# bias = b.repeat(B, 1, H, W)
# t_logits.append(self.token_logits(dyhead_tower["visual"][l] + bias) + self.bias0)
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
x = dyhead_tower["visual"][l]
B, _, H, W = x.shape
C = proj_tokens.shape[2]
proj_queries = self.contrastive_align_projection_image(dyhead_tower["visual"][l])
proj_queries = permute_and_flatten(proj_queries, B, -1, C, H, W)
normalized_img_emb = F.normalize(proj_queries, p=2, dim=-1)
normalized_text_emb = proj_tokens
contrastive_logit = (
torch.matmul(normalized_img_emb, normalized_text_emb.transpose(-1, -2)) / self.log_scale.exp()
)
contrastive_logits.append(contrastive_logit)
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
x = dyhead_tower["visual"][l]
if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
fused_visual_features.append(x)
B, C, H, W = x.shape
# add bias (language)
dot_product_proj_queries = self.dot_product_projection_image(x)
dot_product_proj_queries = permute_and_flatten(dot_product_proj_queries, B, -1, C, H, W)
A = dot_product_proj_queries.shape[1]
bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(1, A, 1)
# add bias (vision)
# b = self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
# tensor.repeat() is supposed to cost more memory, bias = b.repeat(B, 1, H, W)
# here we replace it with tensor.expand()
# bias = b.repeat(B, 1, H, W)
# dot_product_proj_queries = self.dot_product_projection_image(x) + bias
# print(torch.norm(dot_product_proj_tokens))
# exit()
dot_product_logit = (
torch.matmul(dot_product_proj_queries, dot_product_proj_tokens.transpose(-1, -2))
/ self.log_scale.exp()
) + bias
# dot_product_logit = (torch.matmul(dot_product_proj_queries,
# dot_product_proj_tokens.transpose(-1,
# -2)) / self.log_scale.exp()) + self.bias0
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_DOT_PRODUCT:
dot_product_logit = torch.clamp(dot_product_logit, max=50000)
dot_product_logit = torch.clamp(dot_product_logit, min=-50000)
dot_product_logits.append(dot_product_logit)
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
feat = feature
BF, CF, HF, WF = feat.shape
shallow_img_emb = permute_and_flatten(feat, BF, -1, CF, HF, WF)
shallow_img_emb_feats.append(shallow_img_emb)
        # whether the features come from the backbone or from the FPN, the shallow image
        # embeddings are used in the same way
if shallow_img_emb_feats is not None and shallow_text_emb is not None:
# shallow_img_embs = torch.cat(shallow_img_embs, dim=1)
proj_tokens = shallow_text_emb
return (
logits,
bbox_reg,
centerness,
t_logits,
proj_tokens,
contrastive_logits,
dot_product_logits,
mlm_logits,
shallow_img_emb_feats,
fused_visual_features,
)
class VLDyHeadModule(torch.nn.Module):
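    """Wraps VLDyHead with the anchor generator, the ATSS loss evaluator and the train/test
    box post-processors, and routes the language embeddings into the head."""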
def __init__(self, cfg):
super(VLDyHeadModule, self).__init__()
self.cfg = cfg
self.head = VLDyHead(cfg)
box_coder = BoxCoder(cfg)
self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder)
self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True)
self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False)
self.anchor_generator = make_anchor_generator_complex(cfg)
self.lang_model = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
self.joint_embedding_size = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE
self.joint_embedding_dropout = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT
if self.lang_model in ["bert-base-uncased", "roberta-base", "clip", "roberta-fused", "roberta-fused-v2", "roberta-fused-tiny"]:
self.lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
else:
self.lang_dim = 1024
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
self.resizer = FeatureResizer(
input_feat_size=self.lang_dim,
output_feat_size=self.joint_embedding_size,
dropout=self.joint_embedding_dropout,
)
# if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER:
# self.tunable_linear = torch.nn.Linear(self.lang_dim, 1000, bias=False)
# self.tunable_linear.weight.data.fill_(0.0)
def forward(
self,
images,
features,
targets=None,
language_dict_features=None,
positive_map=None,
captions=None,
swint_feature_c4=None,
):
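        """Prepare the text embedding (resized for the contrastive-align head, raw for the
        dot-product head), run the VL head and the anchor generator, then dispatch to
        _forward_train or _forward_test depending on self.training."""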
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
# resizer needed
embedding = language_dict_features["embedded"]
embedding = self.resizer(embedding)
elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
# no resizer needed
embedding = language_dict_features["embedded"]
# print(captions)
# print(embedding)
else:
embedding = None
if "masks" in language_dict_features:
text_masks = language_dict_features["masks"]
else:
text_masks = None
# if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER:
# embedding = self.tunable_linear.weight[:embedding.size(1), :].unsqueeze(0) + embedding
# language_dict_features['embedded'] = embedding
# language_dict_features['hidden'] = self.tunable_linear.weight[:embedding.size(1), :].unsqueeze(0) + language_dict_features['hidden']
(
box_cls,
box_regression,
centerness,
token_logits,
proj_tokens,
contrastive_logits,
dot_product_logits,
mlm_logits,
shallow_img_emb_feats,
fused_visual_features,
) = self.head(features, language_dict_features, embedding, swint_feature_c4)
anchors = self.anchor_generator(images, features)
if self.training:
return self._forward_train(
box_cls,
box_regression,
centerness,
targets,
anchors,
captions,
positive_map,
token_logits,
proj_tokens,
contrastive_logits,
dot_product_logits,
text_masks,
mlm_logits=mlm_logits,
mlm_labels=language_dict_features["mlm_labels"],
shallow_img_emb_feats=shallow_img_emb_feats,
fused_visual_features=fused_visual_features,
)
else:
return self._forward_test(
box_regression,
centerness,
anchors,
box_cls,
token_logits,
dot_product_logits,
positive_map,
fused_visual_features=fused_visual_features,
)
def _forward_train(
self,
box_cls,
box_regression,
centerness,
targets,
anchors,
captions=None,
positive_map=None,
token_logits=None,
proj_tokens=None,
contrastive_logits=None,
dot_product_logits=None,
text_masks=None,
mlm_logits=None,
mlm_labels=None,
shallow_img_emb_feats=None,
fused_visual_features=None,
):
(
loss_box_cls,
loss_box_reg,
loss_centerness,
loss_token,
loss_contrastive_align,
loss_dot_product_token,
loss_shallow_contrastive,
) = self.loss_evaluator(
box_cls,
box_regression,
centerness,
targets,
anchors,
captions,
positive_map,
token_logits,
proj_tokens,
contrastive_logits,
dot_product_logits,
text_masks,
shallow_img_emb_feats,
)
losses = {
# "loss_cls": loss_box_cls,
"loss_reg": loss_box_reg,
"loss_centerness": loss_centerness,
}
if mlm_labels is not None and mlm_logits is not None:
losses["mlm_loss"] = (
nn.CrossEntropyLoss(ignore_index=-100)(mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1))
* self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_COEF
)
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CLASSIFICATION_LOSS:
losses["loss_cls"] = loss_box_cls
else:
losses["loss_cls"] = 0.0 * loss_box_cls
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
losses["loss_token"] = loss_token * self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_LOSS_WEIGHT
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
losses["loss_contrastive_align"] = (
loss_contrastive_align * self.cfg.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_ALIGN_LOSS_WEIGHT
)
if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
losses["loss_dot_product_token"] = (
loss_dot_product_token * self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DOT_PRODUCT_TOKEN_LOSS_WEIGHT
)
if (
self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS
or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS
):
losses["loss_shallow_contrastive"] = (
loss_shallow_contrastive * self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_LOSS_WEIGHT
)
if self.cfg.MODEL.RPN_ONLY:
return None, losses, None
else:
# Let's just use one image per batch
assert (box_regression[0].shape[0]) == 1
positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=1)
boxes = self.box_selector_train(
box_regression,
centerness,
anchors,
box_cls,
token_logits,
dot_product_logits,
positive_map=positive_map_label_to_token,
)
train_boxes = []
# for b, a in zip(boxes, anchors):
# a = cat_boxlist(a)
# b.add_field("visibility", torch.ones(b.bbox.shape[0], dtype=torch.bool, device=b.bbox.device))
# del b.extra_fields['scores']
# del b.extra_fields['labels']
# train_boxes.append(cat_boxlist([b, a]))
for b, t in zip(boxes, targets):
tb = t.copy_with_fields(["labels"])
tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device))
train_boxes.append(cat_boxlist([b, tb]))
return train_boxes, losses, fused_visual_features
def _forward_test(
self,
box_regression,
centerness,
anchors,
box_cls=None,
token_logits=None,
dot_product_logits=None,
positive_map=None,
fused_visual_features=None,
):
boxes = self.box_selector_test(
box_regression,
centerness,
anchors,
box_cls,
token_logits,
dot_product_logits,
positive_map,
)
return boxes, {}, fused_visual_features