zdou0830's picture
desco
749745d
raw
history blame
No virus
63.8 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
This file contains specific functions for computing losses on the RPN
file
"""
import torch
from torch import nn
from torch.nn import functional as F
from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler
from ..utils import cat, concat_box_prediction_layers
from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.layers import SigmoidFocalLoss, IOULoss, TokenSigmoidFocalLoss
from maskrcnn_benchmark.utils.comm import get_world_size, reduce_sum
from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd
from maskrcnn_benchmark.utils.shallow_contrastive_loss_helper import *
import pdb
from transformers import AutoTokenizer
# Large finite stand-in for "infinity": the upper object-size bound of the last FCOS
# level, and the fill value (+INF / -INF) marking unassigned areas / IoUs during
# target assignment.
INF = 1e8
class RPNLossComputation(object):
    """
    Compute the RPN objectness (binary classification) loss and box-regression loss.
    """

    def __init__(self, proposal_matcher, fg_bg_sampler, box_coder):
        """
        Arguments:
            proposal_matcher (Matcher): maps an IoU matrix to a GT index (or negative code) per anchor.
            fg_bg_sampler (BalancedPositiveNegativeSampler): selects positive / negative anchors.
            box_coder (BoxCoder): encodes matched GT boxes relative to their anchors.
        """
        self.proposal_matcher = proposal_matcher
        self.fg_bg_sampler = fg_bg_sampler
        self.box_coder = box_coder

    def match_targets_to_anchors(self, anchor, target):
        """Return a BoxList holding, per anchor, its matched GT box plus a "matched_idxs" field."""
        iou_matrix = boxlist_iou(target, anchor)
        matched_idxs = self.proposal_matcher(iou_matrix)
        # RPN labels come from matched_idxs alone, so no GT field needs to be carried over.
        target = target.copy_with_fields([])
        if len(target) == 0:
            matched_targets = target
        else:
            # matched_idxs may hold negative codes (e.g. -2), which would index out of
            # bounds; clamp them to 0. Such anchors are relabelled via matched_idxs later.
            matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, anchors, targets):
        """
        Build per-image objectness labels (1 = fg, 0 = bg, -1 = ignore) and regression targets.
        """
        labels, regression_targets = [], []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            matched = self.match_targets_to_anchors(anchors_per_image, targets_per_image)
            matched_idxs = matched.get_field("matched_idxs")

            image_labels = (matched_idxs >= 0).to(dtype=torch.float32)
            # Ignore anchors that fall outside the image ...
            image_labels[~anchors_per_image.get_field("visibility")] = -1
            # ... and anchors whose best IoU lies between the matcher's thresholds.
            image_labels[matched_idxs == Matcher.BETWEEN_THRESHOLDS] = -1

            if matched.bbox.shape[0] == 0:
                # No GT boxes: regression targets are all-zero placeholders.
                zeros = torch.zeros_like(image_labels)
                image_regression = torch.stack((zeros, zeros, zeros, zeros), dim=1)
            else:
                image_regression = self.box_coder.encode(matched.bbox, anchors_per_image.bbox)

            labels.append(image_labels)
            regression_targets.append(image_regression)
        return labels, regression_targets

    @custom_fwd(cast_inputs=torch.float32)
    def __call__(self, anchors, objectness, box_regression, targets):
        """
        Arguments:
            anchors (list[BoxList])
            objectness (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        labels, regression_targets = self.prepare_targets(anchors, targets)

        pos_masks, neg_masks = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(pos_masks, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(neg_masks, dim=0)).squeeze(1)
        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        # Flatten every level to the (image, location, anchor) order used when the labels
        # were built by concatenating all levels.
        flat_objectness, flat_box_regression = [], []
        for level_objectness, level_box_regression in zip(objectness, box_regression):
            N, _, H, W = level_objectness.shape
            flat_objectness.append(level_objectness.permute(0, 2, 3, 1).reshape(N, -1))
            flat_box_regression.append(
                level_box_regression.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, 4)
            )
        objectness = cat(flat_objectness, dim=1).reshape(-1)
        box_regression = cat(flat_box_regression, dim=1).reshape(-1, 4)

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        # Regression is applied to sampled positives only, normalized by all sampled anchors.
        box_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1.0 / 9,
            size_average=False,
        ) / (sampled_inds.numel())
        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])
        return objectness_loss, box_loss
class FocalLossComputation(object):
    """
    This class computes the RetinaNet loss: a sigmoid focal classification loss plus a
    smooth-L1 box-regression loss on positive anchors.
    """

    def __init__(
        self,
        proposal_matcher,
        box_coder,
        generate_labels_func,
        sigmoid_focal_loss,
        bbox_reg_beta=0.11,
        regress_norm=1.0,
    ):
        """
        Arguments:
            proposal_matcher (Matcher): maps an IoU matrix to a GT index (or negative code) per anchor.
            box_coder (BoxCoder): encodes matched GT boxes relative to their anchors.
            generate_labels_func (callable): builds per-anchor class labels from the matched targets.
            sigmoid_focal_loss (callable): classification loss, called as ``loss(box_cls, labels)``.
            bbox_reg_beta (float): ``beta`` of the smooth-L1 regression loss.
            regress_norm (float): multiplier on the positive-anchor count used to normalize
                the regression loss.
        """
        self.proposal_matcher = proposal_matcher
        self.box_coder = box_coder
        self.box_cls_loss_func = sigmoid_focal_loss
        self.bbox_reg_beta = bbox_reg_beta
        self.copied_fields = ["labels"]
        self.generate_labels_func = generate_labels_func
        self.discard_cases = ["between_thresholds"]
        self.regress_norm = regress_norm

    def match_targets_to_anchors(self, anchor, target, copied_fields=None):
        """
        Match each anchor to a GT box.

        Arguments:
            anchor (BoxList): anchors of one image.
            target (BoxList): GT boxes of one image.
            copied_fields (list[str] | None): GT fields to keep on the matched targets
                (none when omitted).

        Returns:
            BoxList with one matched GT per anchor and a "matched_idxs" field.
        """
        # None instead of a shared mutable [] default.
        if copied_fields is None:
            copied_fields = []
        match_quality_matrix = boxlist_iou(target, anchor)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # Keep only the requested GT fields (e.g. "labels") on the matched targets.
        target = target.copy_with_fields(copied_fields)
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, anchors, targets):
        """
        Build per-image class labels (0 = background, -1 = ignored) and regression targets.
        """
        labels = []
        regression_targets = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            matched_targets = self.match_targets_to_anchors(anchors_per_image, targets_per_image, self.copied_fields)

            matched_idxs = matched_targets.get_field("matched_idxs")
            labels_per_image = self.generate_labels_func(matched_targets)
            labels_per_image = labels_per_image.to(dtype=torch.float32)

            # Background (negative examples)
            bg_indices = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[bg_indices] = 0

            # discard anchors that go out of the boundaries of the image
            if "not_visibility" in self.discard_cases:
                labels_per_image[~anchors_per_image.get_field("visibility")] = -1

            # discard indices that are between thresholds
            if "between_thresholds" in self.discard_cases:
                inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1

            # compute regression targets
            regression_targets_per_image = self.box_coder.encode(matched_targets.bbox, anchors_per_image.bbox)

            labels.append(labels_per_image)
            regression_targets.append(regression_targets_per_image)

        return labels, regression_targets

    @custom_fwd(cast_inputs=torch.float32)
    def __call__(self, anchors, box_cls, box_regression, targets):
        """
        Arguments:
            anchors (list[BoxList])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            retinanet_cls_loss (Tensor)
            retinanet_regression_loss (Tensor)
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        labels, regression_targets = self.prepare_targets(anchors, targets)

        N = len(labels)
        box_cls, box_regression = concat_box_prediction_layers(box_cls, box_regression)

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)
        pos_inds = torch.nonzero(labels > 0).squeeze(1)

        retinanet_regression_loss = smooth_l1_loss(
            box_regression[pos_inds],
            regression_targets[pos_inds],
            beta=self.bbox_reg_beta,
            size_average=False,
        ) / (max(1, pos_inds.numel() * self.regress_norm))

        labels = labels.int()

        # Normalized by the positive count plus the number of images (always > 0).
        retinanet_cls_loss = self.box_cls_loss_func(box_cls, labels) / (pos_inds.numel() + N)

        return retinanet_cls_loss, retinanet_regression_loss
class FCOSLossComputation(object):
    """
    This class computes the FCOS losses: focal classification, IoU box regression
    (weighted by centerness) and centerness BCE.
    """

    def __init__(self, cfg):
        self.cls_loss_func = SigmoidFocalLoss(cfg.MODEL.FOCAL.LOSS_GAMMA, cfg.MODEL.FOCAL.LOSS_ALPHA)
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
        self.center_sampling_radius = cfg.MODEL.FCOS.CENTER_SAMPLING_RADIUS
        self.iou_loss_type = cfg.MODEL.FCOS.IOU_LOSS_TYPE
        self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
        self.use_gt_center = cfg.MODEL.FCOS.USE_GT_CENTER

        # we make use of IOU Loss for bounding boxes regression,
        # but we found that L1 in log scale can yield a similar performance
        self.box_reg_loss_func = IOULoss(self.iou_loss_type)
        self.centerness_loss_func = torch.nn.BCEWithLogitsLoss(reduction="sum")

    def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1.0):
        """
        Center-sampling mask: for every location and GT box, whether the location lies in a
        region of half-size ``stride * radius`` around the GT center, clipped to the GT box.

        This code is from
        https://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/
        maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42

        Arguments:
            gt (Tensor): (num_gts, 4) boxes in xyxy.
            strides (list): stride of each level.
            num_points_per (list[int]): number of locations in each level.
            gt_xs, gt_ys (Tensor): (K,) location coordinates, all levels concatenated.
            radius (float): sampling radius in units of the level stride.

        Returns:
            Mask of shape (K, num_gts). Exception: the legacy "zero first-center" branch
            below returns a (K,) all-zero uint8 mask, which also marks everything as outside.
        """
        num_gts = gt.shape[0]
        K = len(gt_xs)
        # No GT boxes. Check this first: center_x[..., 0] below would raise IndexError
        # on an empty GT dimension.
        if num_gts == 0:
            return gt_xs.new_zeros((K, 0), dtype=torch.bool)
        gt = gt[None].expand(K, num_gts, 4)
        center_x = (gt[..., 0] + gt[..., 2]) / 2
        center_y = (gt[..., 1] + gt[..., 3]) / 2
        center_gt = gt.new_zeros(gt.shape)
        # Legacy FCOS_PLUS guard: a zero center for the first GT is treated as "no gt".
        if center_x[..., 0].sum() == 0:
            return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8)
        beg = 0
        for level, n_p in enumerate(num_points_per):
            end = beg + n_p
            stride = strides[level] * radius
            xmin = center_x[beg:end] - stride
            ymin = center_y[beg:end] - stride
            xmax = center_x[beg:end] + stride
            ymax = center_y[beg:end] + stride
            # limit sample region in gt
            center_gt[beg:end, :, 0] = torch.where(xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0])
            center_gt[beg:end, :, 1] = torch.where(ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1])
            center_gt[beg:end, :, 2] = torch.where(xmax > gt[beg:end, :, 2], gt[beg:end, :, 2], xmax)
            center_gt[beg:end, :, 3] = torch.where(ymax > gt[beg:end, :, 3], gt[beg:end, :, 3], ymax)
            beg = end
        left = gt_xs[:, None] - center_gt[..., 0]
        right = center_gt[..., 2] - gt_xs[:, None]
        top = gt_ys[:, None] - center_gt[..., 1]
        bottom = center_gt[..., 3] - gt_ys[:, None]
        center_bbox = torch.stack((left, top, right, bottom), -1)
        inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        return inside_gt_bbox_mask

    def prepare_targets(self, points, targets):
        """
        Compute labels and (l, t, r, b) regression targets for every location.

        Returns:
            (labels_level_first, reg_targets_level_first): one tensor per FPN level holding
            all images' targets for that level. Reg targets are divided by the level stride
            when NORM_REG_TARGETS is set.
        """
        object_sizes_of_interest = [
            [-1, 64],
            [64, 128],
            [128, 256],
            [256, 512],
            [512, INF],
        ]
        expanded_object_sizes_of_interest = []
        for level, points_per_level in enumerate(points):
            object_sizes_of_interest_per_level = points_per_level.new_tensor(object_sizes_of_interest[level])
            expanded_object_sizes_of_interest.append(
                object_sizes_of_interest_per_level[None].expand(len(points_per_level), -1)
            )

        expanded_object_sizes_of_interest = torch.cat(expanded_object_sizes_of_interest, dim=0)
        num_points_per_level = [len(points_per_level) for points_per_level in points]
        # Read by compute_targets_for_locations (center sampling needs per-level counts).
        self.num_points_per_level = num_points_per_level
        points_all_level = torch.cat(points, dim=0)
        labels, reg_targets = self.compute_targets_for_locations(
            points_all_level, targets, expanded_object_sizes_of_interest
        )

        for i in range(len(labels)):
            labels[i] = torch.split(labels[i], num_points_per_level, dim=0)
            reg_targets[i] = torch.split(reg_targets[i], num_points_per_level, dim=0)

        labels_level_first = []
        reg_targets_level_first = []
        for level in range(len(points)):
            labels_level_first.append(torch.cat([labels_per_im[level] for labels_per_im in labels], dim=0))

            reg_targets_per_level = torch.cat([reg_targets_per_im[level] for reg_targets_per_im in reg_targets], dim=0)

            if self.norm_reg_targets:
                reg_targets_per_level = reg_targets_per_level / self.fpn_strides[level]
            reg_targets_level_first.append(reg_targets_per_level)

        return labels_level_first, reg_targets_level_first

    def compute_targets_for_locations(self, locations, targets, object_sizes_of_interest):
        """
        Assign each location to at most one GT box per image.

        Returns:
            (labels, reg_targets): per-image lists. labels has shape (K,) with 0 for
            background; reg_targets has shape (K, 4) holding (l, t, r, b) distances.
        """
        labels = []
        reg_targets = []
        xs, ys = locations[:, 0], locations[:, 1]
        num_locations = len(locations)

        for targets_per_im in targets:
            assert targets_per_im.mode == "xyxy"
            if self.use_gt_center:
                center = targets_per_im.get_field("cbox")
                bboxes = center.bbox
                area = center.area()
            else:
                bboxes = targets_per_im.bbox
                area = targets_per_im.area()
            labels_per_im = targets_per_im.get_field("labels")

            if bboxes.shape[0] == 0:
                # No GT in this image: every location is background. The min() over the GT
                # dimension below would fail on an empty dimension. The regression targets
                # are placeholders; __call__ only regresses locations with label > 0.
                labels.append(labels_per_im.new_zeros(num_locations))
                reg_targets.append(locations.new_zeros((num_locations, 4)))
                continue

            left = xs[:, None] - bboxes[:, 0][None]
            top = ys[:, None] - bboxes[:, 1][None]
            right = bboxes[:, 2][None] - xs[:, None]
            bottom = bboxes[:, 3][None] - ys[:, None]
            reg_targets_per_im = torch.stack([left, top, right, bottom], dim=2)

            if self.center_sampling_radius > 0:
                is_in_boxes = self.get_sample_region(
                    bboxes, self.fpn_strides, self.num_points_per_level, xs, ys, radius=self.center_sampling_radius
                )
            else:
                # no center sampling, it will use all the locations within a ground-truth box
                is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0

            max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
            # limit the regression range for each location
            is_cared_in_the_level = (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & (
                max_reg_targets_per_im <= object_sizes_of_interest[:, [1]]
            )

            locations_to_gt_area = area[None].repeat(num_locations, 1)
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[is_cared_in_the_level == 0] = INF

            # if there are still more than one objects for a location,
            # we choose the one with minimal area
            locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(dim=1)

            reg_targets_per_im = reg_targets_per_im[range(num_locations), locations_to_gt_inds]
            labels_per_im = labels_per_im[locations_to_gt_inds]
            labels_per_im[locations_to_min_area == INF] = 0

            labels.append(labels_per_im)
            reg_targets.append(reg_targets_per_im)

        return labels, reg_targets

    def compute_centerness_targets(self, reg_targets):
        """Centerness target sqrt(min(l,r)/max(l,r) * min(t,b)/max(t,b)) per row of (l, t, r, b)."""
        left_right = reg_targets[:, [0, 2]]
        top_bottom = reg_targets[:, [1, 3]]
        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
            top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]
        )
        return torch.sqrt(centerness)

    @custom_fwd(cast_inputs=torch.float32)
    def __call__(self, locations, box_cls, box_regression, centerness, targets):
        """
        Arguments:
            locations (list[BoxList])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            centerness (list[Tensor])
            targets (list[BoxList])

        Returns:
            cls_loss (Tensor)
            reg_loss (Tensor)
            centerness_loss (Tensor)
        """
        num_classes = box_cls[0].size(1)
        labels, reg_targets = self.prepare_targets(locations, targets)

        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []
        for level in range(len(labels)):
            box_cls_flatten.append(box_cls[level].permute(0, 2, 3, 1).reshape(-1, num_classes))
            box_regression_flatten.append(box_regression[level].permute(0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels[level].reshape(-1))
            reg_targets_flatten.append(reg_targets[level].reshape(-1, 4))
            centerness_flatten.append(centerness[level].reshape(-1))

        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        centerness_flatten = torch.cat(centerness_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        cls_loss = self.cls_loss_func(box_cls_flatten, labels_flatten.int()) / max(pos_inds.numel(), 1.0)

        if pos_inds.numel() > 0:
            centerness_targets = self.compute_centerness_targets(reg_targets_flatten)

            reg_loss = (
                self.box_reg_loss_func(box_regression_flatten, reg_targets_flatten, centerness_targets)
                / centerness_targets.sum()
            )
            centerness_loss = self.centerness_loss_func(centerness_flatten, centerness_targets) / max(
                pos_inds.numel(), 1.0
            )
        else:
            # No positives: sums over the empty selections give zero losses that stay in the graph.
            reg_loss = box_regression_flatten.sum()
            centerness_loss = centerness_flatten.sum()

        return cls_loss, reg_loss, centerness_loss
# class ATSSLossComputation(object):
class ATSSLossComputation(torch.nn.Module):
def __init__(self, cfg, box_coder):
    """
    Build the ATSS detection + grounding loss module from the config.

    Arguments:
        cfg: config node; MODEL.FOCAL, MODEL.DYHEAD.FUSE_CONFIG, MODEL.LANGUAGE_BACKBONE,
            MODEL.RPN (and MODEL.SWINT for the backbone shallow-contrastive variant) are read here.
        box_coder (BoxCoder): encodes/decodes box-regression targets.
    """
    super(ATSSLossComputation, self).__init__()

    self.cfg = cfg
    self.cls_loss_func = SigmoidFocalLoss(cfg.MODEL.FOCAL.LOSS_GAMMA, cfg.MODEL.FOCAL.LOSS_ALPHA)
    self.centerness_loss_func = torch.nn.BCEWithLogitsLoss(reduction="sum")
    self.matcher = Matcher(cfg.MODEL.FOCAL.FG_IOU_THRESHOLD, cfg.MODEL.FOCAL.BG_IOU_THRESHOLD, True)
    self.box_coder = box_coder

    # Token-level focal loss, only when a token loss is enabled.
    # NOTE(review): arguments are passed as (alpha, gamma), the reverse of SigmoidFocalLoss
    # above; presumably this matches TokenSigmoidFocalLoss's signature — verify.
    if (
        self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS
        or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS
    ):
        self.token_loss_func = TokenSigmoidFocalLoss(
            cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_ALPHA, cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_GAMMA
        )

    # The fused RoBERTa variants all use the "roberta-base" tokenizer.
    if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE in ["roberta-fused", "roberta-fused-v2", "roberta-fused-tiny"]:
        self.lang = "roberta-base"
    else:
        self.lang = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE

    # Tokenizer used to tokenize captions; its output is later used to map character
    # spans of each GT phrase to token positions.
    # self.tokenizer = AutoTokenizer.from_pretrained(self.lang)
    if self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
        from transformers import CLIPTokenizerFast

        # self.tokenizer = build_tokenizer(self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
        if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
            # An existing vocabulary token is repurposed as the mask token for the MLM loss.
            print("Reuse token 'ðŁĴij</w>' (token_id = 49404) for mask token!")
            self.tokenizer = CLIPTokenizerFast.from_pretrained(
                "openai/clip-vit-base-patch32", from_slow=True, mask_token="ðŁĴij</w>"
            )
        else:
            self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True)
    else:
        self.tokenizer = AutoTokenizer.from_pretrained(self.lang)

    # if use shallow contrastive loss
    # The two shallow-contrastive variants are mutually exclusive (asserted below); they
    # differ only in the input dimension of the image projection.
    if (
        self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS
        or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS
    ):
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
            assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS == False
            # Head features: DYHEAD channels times anchors per location.
            channels = cfg.MODEL.DYHEAD.CHANNELS
            num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
            shallow_input_dim = channels * num_anchors
        elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
            assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS == False
            # Backbone features: second-to-last Swin output width.
            shallow_input_dim = cfg.MODEL.SWINT.OUT_CHANNELS[-2]

        shallow_log_scale = self.cfg.MODEL.DYHEAD.SHALLOW_LOG_SCALE
        shallow_contrastive_hdim = cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_HIDDEN_DIM
        # self.shallow_contrastive_projection_image = nn.Conv2d(channels, num_anchors * shallow_contrastive_hdim,
        #                                                       kernel_size=1)
        # Image and text features are projected into a shared space of width shallow_contrastive_hdim.
        self.shallow_contrastive_projection_image = nn.Linear(
            shallow_input_dim, shallow_contrastive_hdim, bias=True
        )
        self.shallow_contrastive_projection_text = nn.Linear(
            self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, shallow_contrastive_hdim, bias=True
        )
        # Learnable scalar initialized from SHALLOW_LOG_SCALE.
        self.shallow_log_scale = nn.Parameter(torch.Tensor([shallow_log_scale]), requires_grad=True)

    # (initialization) if use shallow contrastive loss
    # Only the head variant is re-initialized; the backbone variant keeps PyTorch's default
    # nn.Linear init. Both projections are nn.Linear, so the Conv2d branch is not taken here.
    if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
        for modules in [self.shallow_contrastive_projection_image, self.shallow_contrastive_projection_text]:
            for l in modules.modules():
                if isinstance(l, nn.Conv2d):
                    torch.nn.init.normal_(l.weight, std=0.01)
                    torch.nn.init.constant_(l.bias, 0)
                if isinstance(l, nn.Linear):
                    torch.nn.init.xavier_uniform_(l.weight)
                    l.bias.data.fill_(0)
def NllSoftMaxLoss(self, logits, target):
    """
    Unreduced soft-target negative log-likelihood: ``-target * log_softmax(logits)`` over
    the last dimension. Entries whose target weight is zero contribute zero loss.
    """
    log_probs = logits.log_softmax(-1)
    return -target * log_probs
def ContrastiveAlignLoss(self, logits, positive_map):
    """
    Symmetric contrastive alignment loss between boxes and tokens.

    For each reduction direction (dim 2: box -> token, dim 1: token -> box) the per-row term is
    ``-mean(logits of positives) + logsumexp(all logits)``. Rows with no positive are dropped.
    The result is the average of the two direction sums.

    Arguments:
        logits (Tensor): similarity scores; reduced over dims 1 and 2.
        positive_map (BoolTensor): same shape as ``logits``; True marks positive pairs.
    """
    negated_positive_logits = -logits.masked_fill(~positive_map, 0)

    def _direction_loss(dim):
        # Rows without any positive would otherwise only contribute the logsumexp term.
        has_positive = positive_map.any(dim)
        positive_sum = negated_positive_logits.sum(dim)
        log_partition = logits.logsumexp(dim)
        positive_count = positive_map.sum(dim) + 1e-6
        per_row = positive_sum / positive_count + log_partition
        return per_row.masked_fill(~has_positive, 0).sum()

    box_to_token = _direction_loss(2)
    token_to_box = _direction_loss(1)
    return (box_to_token + token_to_box) / 2
def GIoULoss(self, pred, target, anchor, weight=None):
    """
    Generalized-IoU loss (``1 - GIoU``) between decoded predicted and target boxes.

    Arguments:
        pred, target (Tensor): regression deltas, viewed as (-1, 4) and decoded against ``anchor``.
        anchor (Tensor): anchors, viewed as (-1, 4).
        weight (Tensor, optional): per-box weights. Used only when their sum is positive;
            otherwise the unweighted sum is returned.

    Returns:
        Scalar tensor: summed (optionally weighted) loss.
    """
    flat_anchor = anchor.view(-1, 4)

    decoded_pred = self.box_coder.decode(pred.view(-1, 4), flat_anchor)
    px1, py1, px2, py2 = (decoded_pred[:, k] for k in range(4))
    # Force x2 >= x1 and y2 >= y1 so the predicted area is non-negative.
    px2 = torch.max(px1, px2)
    py2 = torch.max(py1, py2)
    pred_area = (px2 - px1) * (py2 - py1)

    decoded_target = self.box_coder.decode(target.view(-1, 4), flat_anchor)
    tx1, ty1, tx2, ty2 = (decoded_target[:, k] for k in range(4))
    target_area = (tx2 - tx1) * (ty2 - ty1)

    # Intersection area; zero where the boxes do not overlap.
    ix1 = torch.max(px1, tx1)
    iy1 = torch.max(py1, ty1)
    ix2 = torch.min(px2, tx2)
    iy2 = torch.min(py2, ty2)
    intersection = torch.zeros(px1.size()).to(pred)
    overlapping = (iy2 > iy1) * (ix2 > ix1)
    intersection[overlapping] = (ix2[overlapping] - ix1[overlapping]) * (iy2[overlapping] - iy1[overlapping])

    # Smallest enclosing box.
    ex1 = torch.min(px1, tx1)
    ey1 = torch.min(py1, ty1)
    ex2 = torch.max(px2, tx2)
    ey2 = torch.max(py2, ty2)
    enclosing_area = (ex2 - ex1) * (ey2 - ey1) + 1e-7

    union_area = pred_area + target_area - intersection + 1e-7
    iou = intersection / union_area
    giou = iou - (enclosing_area - union_area) / enclosing_area
    losses = 1 - giou

    if weight is not None and weight.sum() > 0:
        return (losses * weight).sum()
    assert losses.numel() != 0
    return losses.sum()
def prepare_targets(self, targets, anchors, tokenized=None, positive_map=None, proj_tokens=None):
    """
    Assign ATSS targets to every anchor of every image.

    Outputs are class labels and box-regression targets, plus token / grounding targets
    when ``positive_map`` (and ``proj_tokens``) are given.

    Arguments:
        targets (list[BoxList]): per-image GT boxes in "xyxy" mode with a "labels" field.
        anchors (list[list[BoxList]]): per-image, per-level anchors.
        tokenized: tokenizer output for the captions; its ``char_to_token`` is used to map
            GT phrase character spans to token positions (needed when ``proj_tokens`` is given).
        positive_map (Tensor, optional): GT-box-to-token map whose rows are the GT boxes of
            all images, concatenated image by image.
        proj_tokens (Tensor, optional): only its size along dim 1 (number of tokens) is read.

    Returns:
        Tuple of per-image lists: (cls_labels, reg_targets, token_labels, map_labels,
        gold_box_od_labels, od_label_of_tokens_labels, positive_indices).
        - token_labels is filled only when positive_map is given.
        - map_labels, gold_box_od_labels and od_label_of_tokens_labels are filled only
          when both positive_map and proj_tokens are given.
    """
    cls_labels = []
    reg_targets = []
    token_labels = []
    map_labels = []
    gold_box_od_labels = []
    od_label_of_tokens_labels = []
    positive_indices = []
    offset = 0

    for im_i in range(len(targets)):
        targets_per_im = targets[im_i]
        assert targets_per_im.mode == "xyxy"
        # bboxes_per_im = targets_per_im.get_field("boxes")
        bboxes_per_im = targets_per_im.bbox
        labels_per_im = targets_per_im.get_field("labels")
        num_gt = len(bboxes_per_im)

        if positive_map is not None:
            # This image's GT rows inside the batch-wide positive_map.
            token_per_im = positive_map[offset : offset + num_gt, :]
            offset += num_gt

        # Shallow-contrastive fields.
        # NOTE(review): if a field is absent, the variable keeps the previous image's value
        # (or is undefined for the first image) yet is still used when proj_tokens is given.
        # Confirm the dataset always provides these fields in that mode.
        if "original_od_label" in targets_per_im.fields():
            gold_box_od_label = targets_per_im.get_field("original_od_label")
        if "positive_map_for_od_labels" in targets_per_im.fields():
            od_label_of_token_per_im = targets_per_im.get_field("positive_map_for_od_labels")

        if positive_map is not None and proj_tokens is not None:
            if "tokens_positive" in targets_per_im.fields():
                cur_tokens = targets_per_im.get_field("tokens_positive")
            else:
                cur_tokens = targets_per_im.get_field("tokens")
            # token_map[j, t] is True when token t falls inside a character span of GT box j.
            token_map = torch.zeros((len(cur_tokens), proj_tokens.shape[1]), dtype=torch.bool)
            for j, tok_list in enumerate(cur_tokens):
                for (beg, end) in tok_list:
                    beg_pos = tokenized.char_to_token(im_i, beg)
                    end_pos = tokenized.char_to_token(im_i, end - 1)
                    # A span boundary may fall on a character that maps to no token;
                    # retry one or two characters further inside the span.
                    if beg_pos is None:
                        try:
                            beg_pos = tokenized.char_to_token(im_i, beg + 1)
                            if beg_pos is None:
                                beg_pos = tokenized.char_to_token(im_i, beg + 2)
                        except Exception:
                            beg_pos = None
                    if end_pos is None:
                        try:
                            end_pos = tokenized.char_to_token(im_i, end - 2)
                            if end_pos is None:
                                end_pos = tokenized.char_to_token(im_i, end - 3)
                        except Exception:
                            end_pos = None
                    if beg_pos is None or end_pos is None:
                        continue
                    token_map[j, beg_pos : end_pos + 1].fill_(True)

        anchors_per_im = cat_boxlist(anchors[im_i])

        num_anchors_per_loc = len(self.cfg.MODEL.RPN.ASPECT_RATIOS) * self.cfg.MODEL.RPN.SCALES_PER_OCTAVE
        num_anchors_per_level = [len(anchors_per_level.bbox) for anchors_per_level in anchors[im_i]]
        ious = boxlist_iou(anchors_per_im, targets_per_im)

        gt_cx = (bboxes_per_im[:, 2] + bboxes_per_im[:, 0]) / 2.0
        gt_cy = (bboxes_per_im[:, 3] + bboxes_per_im[:, 1]) / 2.0
        gt_points = torch.stack((gt_cx, gt_cy), dim=1)

        anchors_cx_per_im = (anchors_per_im.bbox[:, 2] + anchors_per_im.bbox[:, 0]) / 2.0
        anchors_cy_per_im = (anchors_per_im.bbox[:, 3] + anchors_per_im.bbox[:, 1]) / 2.0
        anchor_points = torch.stack((anchors_cx_per_im, anchors_cy_per_im), dim=1)

        distances = (anchor_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()

        # Selecting candidates based on the center distance between anchor box and object
        candidate_idxs = []
        star_idx = 0
        for level, anchors_per_level in enumerate(anchors[im_i]):
            end_idx = star_idx + num_anchors_per_level[level]
            distances_per_level = distances[star_idx:end_idx, :]
            topk = min(self.cfg.MODEL.ATSS.TOPK * num_anchors_per_loc, num_anchors_per_level[level])
            _, topk_idxs_per_level = distances_per_level.topk(topk, dim=0, largest=False)
            candidate_idxs.append(topk_idxs_per_level + star_idx)
            star_idx = end_idx
        candidate_idxs = torch.cat(candidate_idxs, dim=0)

        # Using the sum of mean and standard deviation as the IoU threshold to select final positive samples
        candidate_ious = ious[candidate_idxs, torch.arange(num_gt)]
        iou_mean_per_gt = candidate_ious.mean(0)
        iou_std_per_gt = candidate_ious.std(0)
        iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt
        is_pos = candidate_ious >= iou_thresh_per_gt[None, :]

        # Limiting the final positive samples' center to object
        anchor_num = anchors_cx_per_im.shape[0]
        # Offset candidate indices so they address a GT-major flattened (num_gt * anchor_num) layout.
        for ng in range(num_gt):
            candidate_idxs[:, ng] += ng * anchor_num
        e_anchors_cx = anchors_cx_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
        e_anchors_cy = anchors_cy_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
        candidate_idxs = candidate_idxs.view(-1)
        if num_gt == 0:
            l = e_anchors_cx[candidate_idxs] - bboxes_per_im[:, 0]
            t = e_anchors_cy[candidate_idxs] - bboxes_per_im[:, 1]
            r = bboxes_per_im[:, 2] - e_anchors_cx[candidate_idxs]
            b = bboxes_per_im[:, 3] - e_anchors_cy[candidate_idxs]
        else:
            l = e_anchors_cx[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 0]
            t = e_anchors_cy[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 1]
            r = bboxes_per_im[:, 2] - e_anchors_cx[candidate_idxs].view(-1, num_gt)
            b = bboxes_per_im[:, 3] - e_anchors_cy[candidate_idxs].view(-1, num_gt)
        is_in_gts = torch.stack([l, t, r, b], dim=1).min(dim=1)[0] > 0.01
        is_pos = is_pos & is_in_gts

        # if an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
        ious_inf = torch.full_like(ious, -INF).t().contiguous().view(-1)
        index = candidate_idxs.view(-1)[is_pos.view(-1)]
        ious_inf[index] = ious.t().contiguous().view(-1)[index]
        if num_gt > 0:
            ious_inf = ious_inf.view(num_gt, -1).t()
            anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1)
            # get positive anchors index from ATSS
            # NOTE(review): this keeps anchors whose assigned GT index is non-zero. Anchors
            # matched to GT 0 are therefore excluded, and the test is not
            # anchors_to_gt_values != -INF. Confirm this is intended by the consumer.
            positive_index = [i[0].item() for i in torch.nonzero(anchors_to_gt_indexs)]
            cls_labels_per_im = labels_per_im[anchors_to_gt_indexs]
            cls_labels_per_im[anchors_to_gt_values == -INF] = 0
        else:
            cls_labels_per_im = torch.zeros((ious.size(0)), device=labels_per_im.device)
            anchors_to_gt_values, anchors_to_gt_indexs = [], []
            # No GT, no positive anchors. Without this assignment positive_index would be
            # undefined (first image) or stale from the previous image.
            positive_index = []

        if positive_map is not None:
            if num_gt > 0:
                token_labels_per_im = token_per_im[anchors_to_gt_indexs]
                unmatched_labels = torch.zeros(token_labels_per_im.shape[1], device=token_labels_per_im.device)
                # TODO: temporarially disable the [NoObj] token logic, and only restrict to binary loss
                if not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MUTE_NOOBJ_TOKEN:
                    unmatched_labels[-1] = 1  # token: none object - > 256
                token_labels_per_im[anchors_to_gt_values == -INF] = unmatched_labels
                # move from cpu to gpu
                token_labels_per_im = token_labels_per_im.to(cls_labels_per_im.device)
            else:
                unmatched_labels = torch.zeros(token_per_im.size(1), device=token_per_im.device)
                if not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MUTE_NOOBJ_TOKEN:
                    unmatched_labels[-1] = 1  # token: none object - > 256
                token_labels_per_im = unmatched_labels.unsqueeze(0).repeat(ious.size(0), 1)
                token_labels_per_im = token_labels_per_im.to(cls_labels_per_im.device)

        if positive_map is not None and proj_tokens is not None:  # TODO fix this for no box case
            map_labels_per_im = token_map[anchors_to_gt_indexs]
            unmatched_labels = torch.zeros(
                map_labels_per_im.shape[1], dtype=torch.bool, device=map_labels_per_im.device
            )  # map: none False
            map_labels_per_im[anchors_to_gt_values == -INF] = unmatched_labels
            # move from cpu to gpu
            map_labels_per_im = map_labels_per_im.to(cls_labels_per_im.device)

        if positive_map is not None and proj_tokens is not None:
            gold_box_od_label_per_im = gold_box_od_label[anchors_to_gt_indexs]
            gold_box_od_label_per_im[anchors_to_gt_values == -INF] = -100
            # move from cpu to gpu
            gold_box_od_label_per_im = gold_box_od_label_per_im.to(cls_labels_per_im.device)

        if num_gt > 0:
            matched_gts = bboxes_per_im[anchors_to_gt_indexs]
            reg_targets_per_im = self.box_coder.encode(matched_gts, anchors_per_im.bbox)
        else:
            # No GT: encoding 0 matched boxes against all anchors cannot yield one target per
            # anchor. Every anchor is background, so use zero placeholders, as
            # RPNLossComputation does for empty targets. Presumably only positive anchors are
            # regressed downstream — confirm in __call__.
            reg_targets_per_im = torch.zeros_like(anchors_per_im.bbox)

        cls_labels.append(cls_labels_per_im)
        reg_targets.append(reg_targets_per_im)

        if positive_map is not None:
            token_labels.append(token_labels_per_im)

        if positive_map is not None and proj_tokens is not None:
            map_labels.append(map_labels_per_im)
            gold_box_od_labels.append(gold_box_od_label_per_im)
            od_label_of_tokens_labels.append(od_label_of_token_per_im)

        positive_indices.append(positive_index)

    return (
        cls_labels,
        reg_targets,
        token_labels,
        map_labels,
        gold_box_od_labels,
        od_label_of_tokens_labels,
        positive_indices,
    )
def compute_centerness_targets(self, reg_targets, anchors):
    """
    Centerness target of each anchor center with respect to its decoded GT box:
    ``sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b))``.

    Arguments:
        reg_targets (Tensor): encoded regression targets, decoded against ``anchors``.
        anchors (Tensor): anchor boxes as (x1, y1, x2, y2) columns.

    Raises:
        AssertionError: if any centerness value is NaN.
    """
    decoded = self.box_coder.decode(reg_targets, anchors)
    center_x = (anchors[:, 2] + anchors[:, 0]) / 2
    center_y = (anchors[:, 3] + anchors[:, 1]) / 2

    horizontal = torch.stack([center_x - decoded[:, 0], decoded[:, 2] - center_x], dim=1)
    vertical = torch.stack([center_y - decoded[:, 1], decoded[:, 3] - center_y], dim=1)

    ratio_x = horizontal.min(dim=-1)[0] / horizontal.max(dim=-1)[0]
    ratio_y = vertical.min(dim=-1)[0] / vertical.max(dim=-1)[0]
    centerness = torch.sqrt(ratio_x * ratio_y)

    assert not torch.isnan(centerness).any()
    return centerness
@custom_fwd(cast_inputs=torch.float32)
def __call__(
    self,
    box_cls,
    box_regression,
    centerness,
    targets,
    anchors,
    captions=None,
    positive_map=None,
    token_logits=None,
    proj_tokens=None,
    contrastive_logits=None,
    dot_product_logits=None,
    text_masks=None,
    shallow_img_emb_feats=None,
):
    """Compute the ATSS detection losses plus the optional grounding losses.

    Steps as implemented below:
      1. Optionally tokenize ``captions`` with ``self.tokenizer``.
      2. Build per-image targets with ``self.prepare_targets``.
      3. Flatten per-level predictions via ``concat_box_prediction_layers``.
      4. When ``positive_map``, ``proj_tokens`` and ``shallow_img_emb_feats`` are
         all given, build the cross-rank shallow contrastive positive map and logits.
      5. Compute classification, GIoU regression and centerness losses, plus
         whichever grounding losses are enabled in ``cfg.MODEL.DYHEAD.FUSE_CONFIG``.

    Args:
        box_cls, box_regression, token_logits: per-level predictions, flattened by
            ``concat_box_prediction_layers`` (presumably one tensor per FPN level).
        centerness: list of per-level tensors. Each is permuted (0, 2, 3, 1) and
            reshaped to (N, -1, 1), so each must be 4-D.
        targets: per-image ground truth. Forwarded to ``prepare_targets``; when
            MUTE_NON_ESSENTIAL_TOKENS is set, each element must expose
            ``get_field("greenlight_map")``.
        anchors: per-image lists of anchor BoxLists (joined with ``cat_boxlist``).
        captions: optional caption strings to tokenize.
        positive_map: enables token-label targets and the grounding branches.
        proj_tokens: projected text tokens. ``shape[0]`` is used as this rank's
            batch size.
        contrastive_logits, dot_product_logits: optional lists, concatenated
            along dim 1.
        text_masks: per-sample text masks used by the token losses.
        shallow_img_emb_feats: image features for the shallow contrastive loss.
            Either per-image backbone maps (ROIAlign path) or a list concatenated
            along dim 1 (post-FPN path), depending on the config flags.

    Returns:
        ``(cls_loss, reg_loss * cfg.MODEL.ATSS.REG_LOSS_WEIGHT, centerness_loss,
        token_logits_loss, contrastive_align_loss, dot_product_token_loss,
        shallow_contrastive_loss)``. The last four are ``None`` when their
        config flag is off.

    NOTE(review): ``custom_fwd(cast_inputs=torch.float32)`` comes from
    ``maskrcnn_benchmark.utils.amp``. It presumably casts floating inputs to
    float32 under autocast; confirm.
    """
    tokenized = None
    if captions is not None:
        # CLIP tokenizer: truncates to MAX_QUERY_LEN and optionally pads to it.
        # Any other tokenizer pads to the longest caption in the batch.
        if self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
            tokenized = self.tokenizer.batch_encode_plus(
                captions,
                max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
                padding="max_length" if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest",
                return_tensors="pt",
                truncation=True,
            )
        else:
            tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt")
    # Per-image target lists built by prepare_targets.
    (
        labels,
        reg_targets,
        token_labels,
        map_labels,
        gold_box_od_labels,
        od_label_of_tokens_labels,
        positive_indices,
    ) = self.prepare_targets(targets, anchors, tokenized, positive_map, proj_tokens)
    N = len(labels)  # number of images on this rank
    box_regression_flatten, box_cls_flatten, token_logits_stacked = concat_box_prediction_layers(
        box_regression,
        box_cls,
        token_logits,
    )
    # contrastive logits # TODO: fix no box case here
    if positive_map is not None and contrastive_logits is not None:
        contrastive_logits = torch.cat(contrastive_logits, dim=1)
    # dot product soft token logits
    if dot_product_logits is not None:
        dot_product_logits = torch.cat(dot_product_logits, dim=1)
    # Each level: permute (0, 2, 3, 1) (presumably NCHW -> NHWC), reshape to (N, -1, 1);
    # then concatenate levels and flatten to 1-D so entries line up with labels_flatten.
    centerness_flatten = [ct.permute(0, 2, 3, 1).reshape(N, -1, 1) for ct in centerness]
    centerness_flatten = torch.cat(centerness_flatten, dim=1).reshape(-1)
    labels_flatten = torch.cat(labels, dim=0)
    reg_targets_flatten = torch.cat(reg_targets, dim=0)
    anchors_flatten = torch.cat([cat_boxlist(anchors_per_image).bbox for anchors_per_image in anchors], dim=0)
    if positive_map is not None:
        token_labels_stacked = torch.stack(token_labels, dim=0)
    # NOTE(review): proj_map is not read anywhere later in this method.
    if positive_map is not None and proj_tokens is not None:  # TODO: fix no box case here
        proj_map = torch.stack(map_labels, dim=0)
    if positive_map is not None and proj_tokens is not None:
        positive_map_box_to_self_text = None
        shallow_positive_map = None
        bs = proj_tokens.shape[0]  # batch size on this rank (used with rank * bs offsets below)
        device = proj_tokens.device
        # NOTE: 0. setup env
        # NOTE(review): `dist` is not imported in the visible file header. It presumably
        # comes from the star-import of shallow_contrastive_loss_helper; confirm.
        if dist.is_dist_avail_and_initialized():
            world_size = dist.get_world_size()
            rank = torch.distributed.get_rank()
        else:
            world_size = 1
            rank = 0
        if contrastive_logits is not None:
            positive_map_box_to_self_text = torch.stack(map_labels, dim=0)
        if shallow_img_emb_feats is not None:
            """
Ultimate:
N*B*(max_anchor_num) x N*B*T
Final Goal:
F = B x (max_anchor_num) x N*B*T
X: B x (max_anchor_num) od_labels : [0, 20, 30, ..]
Y: N*B*T: which denotes the od_label of every token
F[i,j] = A[i] == B[j]
"""
            # Target construction needs no gradients. Steps 6-7 further down are
            # outside this context so the shallow projection heads receive gradients.
            with torch.no_grad():
                # NOTE: 1. get X (predicted_box_od_label), which the detection label of every predicted boxes
                # predicted_box_od_label: B x A
                # check memory limitation: prevent # of positive >= # of max_positive
                # Images with too many positives are randomly subsampled (kept sorted)
                # down to SHALLOW_MAX_POSITIVE_ANCHORS.
                new_positive_indices = []
                for positive_index in positive_indices:
                    if len(positive_index) >= self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS:
                        import random
                        positive_index = sorted(
                            random.sample(
                                positive_index, self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS
                            )
                        )
                    new_positive_indices.append(positive_index)
                # Every image on every rank is padded to the same number of anchor
                # slots: the per-rank max is all-gathered and the global max is used.
                max_len = max([len(positive_index) for positive_index in new_positive_indices])
                max_anchor_num = max_len
                if world_size > 1:
                    num_anchors = torch.tensor(max_len, device=positive_map.device)
                    num_anchors_full = [torch.zeros_like(num_anchors) for _ in range(world_size)]
                    torch.distributed.all_gather(num_anchors_full, num_anchors)
                    max_anchor_num = max([anchor.item() for anchor in num_anchors_full])
                new_negative_pad_indices = []
                # if not PAD_ZEROS, select random negative paddings
                if not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS:
                    for (positive_index, old_positive_index) in zip(new_positive_indices, positive_indices):
                        # Candidates: every anchor index (counted from anchors[0], presumably the
                        # same count for all images) that is not an original positive of this image.
                        # NOTE(review): if old_positive_index is a list, `not in` is a linear scan,
                        # making this O(anchors * positives). A set would be O(anchors); confirm
                        # the element type produced by prepare_targets before changing it.
                        negative_index = [
                            i for i in range(len(cat_boxlist(anchors[0]))) if i not in old_positive_index
                        ]
                        import random
                        # Fill the remaining slots up to max_anchor_num with random negatives.
                        negative_pad_index = sorted(
                            random.sample(negative_index, max_anchor_num - len(positive_index))
                        )
                        new_negative_pad_indices.append(negative_pad_index)
                # Padded slots get od label -100.
                # NOTE(review): if a token's od label can also be -100, padded boxes would
                # match it in the equality map of step 3; confirm.
                predicted_box_od_label = []
                for i in range(bs):
                    predicted_box_od_label.append(
                        pad_tensor_given_dim_length(
                            gold_box_od_labels[i][new_positive_indices[i]],
                            dim=0,
                            length=max_anchor_num,
                            padding_value=-100,
                            batch_first=False,
                        )
                    )
                predicted_box_od_label = torch.stack(predicted_box_od_label, dim=0)
                # if padding, need to create image masks to filter out the paddings
                image_masks = None
                if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS:
                    image_masks = torch.zeros((bs, max_anchor_num), dtype=torch.long).to(text_masks.device)
                    for i in range(bs):
                        image_masks[i, : len(new_positive_indices[i])] = 1
                # NOTE: 2. Get Y (od_label_of_tokens)
                # od_label_of_tokens: N x B x T
                od_label_of_tokens = torch.stack(od_label_of_tokens_labels, dim=0).long()
                od_label_of_tokens = gather_tensors(od_label_of_tokens)
                # NOTE: 3. get F
                # F: B*A x N*B*T  (box i matches token j when their od labels are equal)
                mapping_predicted_box_to_all_text = predicted_box_od_label.view(-1).unsqueeze(
                    1
                ) == od_label_of_tokens.view(-1).unsqueeze(0)
                # NOTE: 4. we still need to calculate the mapping between predicted box to its corresponding text's mapping
                # positive_map_box_to_self_text: B x A x T, leave this for vanilla contrastive alignment loss
                # (this overwrites the unpadded stack assigned above when contrastive_logits is given)
                positive_map_box_to_self_text = []
                for i in range(bs):
                    positive_map_box_to_self_text.append(
                        pad_tensor_given_dim_length(
                            map_labels[i][new_positive_indices[i]],
                            dim=0,
                            length=max_anchor_num,
                            padding_value=False,
                            batch_first=False,
                        )
                    )
                positive_map_box_to_self_text = torch.stack(positive_map_box_to_self_text, dim=0)
                # change the corresponding place in our batch
                # For this rank's own image/caption pairs, the label-equality entries are
                # replaced by the exact box-to-token map. The stride 256 matches the length
                # text embeddings are padded to in step 6; TODO confirm vs MAX_QUERY_LEN.
                for i in range(bs):
                    mapping_predicted_box_to_all_text[
                        i * max_anchor_num : (i + 1) * max_anchor_num,
                        (rank * bs + i) * 256 : (rank * bs + i + 1) * 256,
                    ] = positive_map_box_to_self_text[i]
                # NOTE: 5. communicate and get positive map
                # mapping_predicted_box_to_all_text: N*B*A x N*B*T
                mapping_predicted_box_to_all_text = gather_tensors(mapping_predicted_box_to_all_text).view(
                    -1, mapping_predicted_box_to_all_text.size(-1)
                )
                shallow_positive_map = mapping_predicted_box_to_all_text  # This is the true positive map
                shallow_positive_map = shallow_positive_map.unsqueeze(0)
                # Get text attention masks
                text_attention_mask = torch.zeros((bs, 256), dtype=torch.long)  # B x 256
                for i in range(bs):
                    text_attention_mask[i, : len(text_masks[i])] = text_masks[i]
                text_attention_mask = gather_tensors(text_attention_mask.bool().to(device))  # N x B x 256
                # if PAD_ZEROS, get image masks
                if image_masks is not None:
                    image_attention_mask = torch.zeros((bs, max_anchor_num), dtype=torch.long)  # B x max_anchor
                    for i in range(bs):
                        image_attention_mask[i, : len(image_masks[i])] = image_masks[i]
                    image_attention_mask = gather_tensors(
                        image_attention_mask.bool().to(device)
                    )  # N x B x max_anchor
            # NOTE: 6. calculate shallow contrastive logits
            shallow_proj_tokens = F.normalize(self.shallow_contrastive_projection_text(proj_tokens), p=2, dim=-1)
            shallow_normalized_img_embs = []
            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
                # choice 1:use features from SWINT backbone layer (c4) before vl fusion
                from maskrcnn_benchmark.layers.roi_align import ROIAlignV2
                # 1x1 ROIAlign at spatial scale 1/16 pools one feature vector per anchor box.
                pooler = ROIAlignV2((1, 1), 1.0 / 16, 0)
                # get positive features
                for i in range(bs):
                    rois = convert_to_roi_format(cat_boxlist(anchors[i])[new_positive_indices[i]])
                    roi_feature = pooler(shallow_img_emb_feats[i].unsqueeze(0), rois)
                    roi_feature = roi_feature.squeeze(-1).squeeze(-1)
                    shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(roi_feature)
                    shallow_normalized_img_emb = F.normalize(shallow_contrastive_proj_queries, p=2, dim=-1)
                    if image_masks is not None:
                        # pad zeros
                        shallow_normalized_img_embs.append(
                            pad_tensor_given_dim_length(
                                shallow_normalized_img_emb,
                                dim=0,
                                length=max_anchor_num,
                                padding_value=0.0,
                                batch_first=False,
                            )
                        )
                    else:
                        # pad negatives
                        negative_rois = convert_to_roi_format(cat_boxlist(anchors[i])[new_negative_pad_indices[i]])
                        negative_roi_feature = pooler(shallow_img_emb_feats[i].unsqueeze(0), negative_rois)
                        negative_roi_feature = negative_roi_feature.squeeze(-1).squeeze(-1)
                        negative_shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(
                            negative_roi_feature
                        )
                        negative_shallow_normalized_img_emb = F.normalize(
                            negative_shallow_contrastive_proj_queries, p=2, dim=-1
                        )
                        shallow_normalized_img_embs.append(
                            pad_random_negative_tensor_given_length(
                                shallow_normalized_img_emb,
                                negative_shallow_normalized_img_emb,
                                length=max_anchor_num,
                            )
                        )
            elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
                # choice 2:use features after FPN
                shallow_img_embs = torch.cat(shallow_img_emb_feats, dim=1)
                # get positive features
                for i in range(bs):
                    shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(
                        shallow_img_embs[i, new_positive_indices[i], :]
                    )
                    shallow_normalized_img_emb = F.normalize(shallow_contrastive_proj_queries, p=2, dim=-1)
                    if image_masks is not None:
                        # pad zeros
                        shallow_normalized_img_embs.append(
                            pad_tensor_given_dim_length(
                                shallow_normalized_img_emb,
                                dim=0,
                                length=max_anchor_num,
                                padding_value=0.0,
                                batch_first=False,
                            )
                        )
                    else:
                        # pad negatives
                        negative_shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(
                            shallow_img_embs[i, new_negative_pad_indices[i], :]
                        )
                        negative_shallow_normalized_img_emb = F.normalize(
                            negative_shallow_contrastive_proj_queries, p=2, dim=-1
                        )
                        shallow_normalized_img_embs.append(
                            pad_random_negative_tensor_given_length(
                                shallow_normalized_img_emb,
                                negative_shallow_normalized_img_emb,
                                length=max_anchor_num,
                            )
                        )
            shallow_normalized_img_embs = torch.stack(shallow_normalized_img_embs, dim=0)
            shallow_normalized_text_emb = shallow_proj_tokens
            shallow_normalized_text_emb = pad_tensor_given_dim_length(
                shallow_normalized_text_emb, dim=1, length=256, padding_value=0.0
            )
            gathered_shallow_normalized_img_emb = gather_tensors(shallow_normalized_img_embs)
            gathered_shallow_normalized_text_emb = gather_tensors(shallow_normalized_text_emb)
            gathered_shallow_normalized_img_emb = gathered_shallow_normalized_img_emb.view(
                -1, gathered_shallow_normalized_img_emb.size(-1)
            )
            gathered_shallow_normalized_text_emb = gathered_shallow_normalized_text_emb.view(
                -1, gathered_shallow_normalized_text_emb.size(-1)
            )
            # Both sides are L2-normalized, so this is cosine similarity divided by
            # exp(shallow_log_scale).
            shallow_contrastive_logits = (
                torch.matmul(
                    gathered_shallow_normalized_img_emb, gathered_shallow_normalized_text_emb.transpose(-1, -2)
                )
                / self.shallow_log_scale.exp()
            )
            shallow_contrastive_logits = shallow_contrastive_logits.unsqueeze(0)
            # apply text mask
            text_attention_mask = text_attention_mask.view(-1).unsqueeze(0).unsqueeze(0)
            text_attention_mask = text_attention_mask.repeat(
                1, shallow_contrastive_logits.size(1), 1
            )  # copy along the image feature dimension
            shallow_contrastive_logits = shallow_contrastive_logits.masked_fill(~text_attention_mask, -1000000)
            # if PAD ZEROS, apply image mask
            if image_masks is not None:
                image_attention_mask = image_attention_mask.view(-1).unsqueeze(0).unsqueeze(-1)
                image_attention_mask = image_attention_mask.repeat(
                    1, 1, shallow_contrastive_logits.size(2)
                )  # copy along the text feature dimension
                shallow_contrastive_logits = shallow_contrastive_logits.masked_fill(~image_attention_mask, -1000000)
            # Note: 7. calculate image and text logits and maps
            # Rows (boxes) / columns (tokens) are sliced to this rank's own batch. The
            # gathered entries from other ranks stay in the other axis, presumably acting
            # as extra negatives in NllSoftMaxLoss.
            shallow_image_logits = shallow_contrastive_logits[
                :, (rank * bs) * max_anchor_num : (rank * bs + bs) * max_anchor_num, :
            ]
            shallow_image_positive_map = normalized_positive_map(
                shallow_positive_map[:, (rank * bs) * max_anchor_num : (rank * bs + bs) * max_anchor_num, :]
            )
            shallow_text_logits = shallow_contrastive_logits[
                :, :, (rank * bs) * 256 : (rank * bs + bs) * 256
            ].transpose(1, 2)
            shallow_text_positive_map = normalized_positive_map(
                shallow_positive_map[:, :, (rank * bs) * 256 : (rank * bs + bs) * 256].transpose(1, 2)
            )
    pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
    num_gpus = get_world_size()
    # Normalizer: positive count summed over ranks, averaged per GPU, floored at 1.
    total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
    num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)
    # TODO: Do we want to have this loss for the no box case? currently all labels are set to 0 which might not be the current way forward
    # I set it to zero since thats what is chosen in prepare_targets : cls_labels_per_im[anchors_to_gt_values == -INF] = 0
    cls_loss = self.cls_loss_func(box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu
    token_logits_loss = None
    contrastive_align_loss = None
    dot_product_token_loss = None
    shallow_contrastive_loss = None
    if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MUTE_NON_ESSENTIAL_TOKENS:
        # we mask out the non_essential tokens
        greenlight_map = [i.get_field("greenlight_map") for i in targets]
        greenlight_map = torch.stack(greenlight_map, dim=0)
        # make sure the greenlight map is the same size as the text masks
        assert(greenlight_map.size(0) == text_masks.size(0))
        assert(greenlight_map.size(1) == text_masks.size(1))
        _text_masks_for_loss = greenlight_map
    else:
        _text_masks_for_loss = text_masks
    if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
        # NOTE(review): this uses `text_masks`, not `_text_masks_for_loss`, so
        # MUTE_NON_ESSENTIAL_TOKENS only affects the dot-product token loss. Confirm
        # this is intended.
        token_logits_loss = (
            self.token_loss_func(
                token_logits_stacked, token_labels_stacked, text_masks=text_masks, version="binary"
            )
            / num_pos_avg_per_gpu
        )
    if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
        # positive_map_box_to_self_text is only bound when positive_map and proj_tokens
        # are not None. With shallow features it is the padded, subsampled version.
        contrastive_align_loss = (
            self.ContrastiveAlignLoss(contrastive_logits, positive_map_box_to_self_text) / num_pos_avg_per_gpu
        )
    if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
        dot_product_token_loss = (
            self.token_loss_func(dot_product_logits, token_labels_stacked, text_masks=_text_masks_for_loss, version="binary")
            / num_pos_avg_per_gpu
        )
    if (
        self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS
        or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS
    ):
        # NOTE(review): the shallow_* logits/maps are only bound when positive_map,
        # proj_tokens and shallow_img_emb_feats are all non-None. Otherwise this raises
        # NameError.
        # Symmetric loss: average of box->token and token->box directions.
        box_to_token_loss = self.NllSoftMaxLoss(shallow_image_logits, shallow_image_positive_map).sum()
        token_to_box_loss = self.NllSoftMaxLoss(shallow_text_logits, shallow_text_positive_map).sum()
        tot_loss = (box_to_token_loss + token_to_box_loss) / 2
        shallow_contrastive_loss = tot_loss / num_pos_avg_per_gpu
    # Regression and centerness losses use positive anchors only.
    box_regression_flatten = box_regression_flatten[pos_inds]
    reg_targets_flatten = reg_targets_flatten[pos_inds]
    anchors_flatten = anchors_flatten[pos_inds]
    centerness_flatten = centerness_flatten[pos_inds]
    if pos_inds.numel() > 0:
        centerness_targets = self.compute_centerness_targets(reg_targets_flatten, anchors_flatten)
        # GIoU loss is weighted by the centerness targets and normalized by their
        # sum averaged across GPUs.
        sum_centerness_targets_avg_per_gpu = reduce_sum(centerness_targets.sum()).item() / float(num_gpus)
        reg_loss = (
            self.GIoULoss(box_regression_flatten, reg_targets_flatten, anchors_flatten, weight=centerness_targets)
            / sum_centerness_targets_avg_per_gpu
        )
        centerness_loss = self.centerness_loss_func(centerness_flatten, centerness_targets) / num_pos_avg_per_gpu
    else:
        # No positives: sums over empty selections give zero-valued losses.
        reg_loss = box_regression_flatten.sum()
        # Issues the same reduce_sum collective as the branch above, presumably so ranks
        # with and without positives stay in lockstep.
        reduce_sum(centerness_flatten.new_tensor([0.0]))
        centerness_loss = centerness_flatten.sum()
    return (
        cls_loss,
        reg_loss * self.cfg.MODEL.ATSS.REG_LOSS_WEIGHT,
        centerness_loss,
        token_logits_loss,
        contrastive_align_loss,
        dot_product_token_loss,
        shallow_contrastive_loss,
    )
def generate_anchor_labels(matched_targets):
    """Return the ``"labels"`` field stored on ``matched_targets``.

    Passed as the label-generation callback to ``FocalLossComputation`` by
    ``make_focal_loss_evaluator``.
    """
    return matched_targets.get_field("labels")
def make_focal_loss_evaluator(cfg, box_coder):
    """Build a ``FocalLossComputation`` configured from ``cfg.MODEL.FOCAL``.

    Reads FG_IOU_THRESHOLD, BG_IOU_THRESHOLD, LOSS_GAMMA, LOSS_ALPHA,
    BBOX_REG_BETA and BBOX_REG_WEIGHT from ``cfg.MODEL.FOCAL``. The matcher is
    built with ``allow_low_quality_matches=True``.

    Args:
        cfg: project config node.
        box_coder: box coder forwarded to ``FocalLossComputation``.

    Returns:
        The configured ``FocalLossComputation``.
    """
    focal_cfg = cfg.MODEL.FOCAL
    anchor_matcher = Matcher(
        focal_cfg.FG_IOU_THRESHOLD,
        focal_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )
    focal_loss = SigmoidFocalLoss(focal_cfg.LOSS_GAMMA, focal_cfg.LOSS_ALPHA)
    return FocalLossComputation(
        anchor_matcher,
        box_coder,
        generate_anchor_labels,
        focal_loss,
        bbox_reg_beta=focal_cfg.BBOX_REG_BETA,
        regress_norm=focal_cfg.BBOX_REG_WEIGHT,
    )
def make_rpn_loss_evaluator(cfg, box_coder):
    """Build an ``RPNLossComputation`` configured from ``cfg.MODEL.RPN``.

    Uses FG/BG IoU thresholds for the matcher (low-quality matches allowed) and
    BATCH_SIZE_PER_IMAGE / POSITIVE_FRACTION for the balanced sampler.

    Args:
        cfg: project config node.
        box_coder: box coder forwarded to ``RPNLossComputation``.

    Returns:
        The configured ``RPNLossComputation``.
    """
    rpn_cfg = cfg.MODEL.RPN
    anchor_matcher = Matcher(
        rpn_cfg.FG_IOU_THRESHOLD,
        rpn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )
    sampler = BalancedPositiveNegativeSampler(
        rpn_cfg.BATCH_SIZE_PER_IMAGE,
        rpn_cfg.POSITIVE_FRACTION,
    )
    return RPNLossComputation(anchor_matcher, sampler, box_coder)
def make_fcos_loss_evaluator(cfg):
    """Build an ``FCOSLossComputation`` from the project config ``cfg``."""
    return FCOSLossComputation(cfg)
def make_atss_loss_evaluator(cfg, box_coder):
    """Build an ``ATSSLossComputation`` from ``cfg`` and ``box_coder``."""
    return ATSSLossComputation(cfg, box_coder)