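# NOTE: the following four lines are residue from the web file viewer this
# source was copied from; they are not part of the program.
# 3v324v23's picture
# history blame
# No virus
# 6.12 kB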
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""This file contains specific functions for computing losses on the RPN."""
import torch
from torch.nn import functional as F
from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler
from ..utils import cat
from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
class RPNLossComputation(object):
    """Compute the objectness and box-regression losses of the RPN."""

    def __init__(self, proposal_matcher, fg_bg_sampler, box_coder):
        """Store the collaborators used to build targets and sample anchors.

        Arguments:
            proposal_matcher (Matcher): called with the IoU matrix between
                ground-truth boxes and anchors; returns one index per anchor.
            fg_bg_sampler (BalancedPositiveNegativeSampler): called with the
                per-image label tensors; returns per-image positive and
                negative sampling masks.
            box_coder (BoxCoder): its ``encode`` method turns matched
                ground-truth boxes and anchors into regression targets.
        """
        self.proposal_matcher = proposal_matcher
        self.fg_bg_sampler = fg_bg_sampler
        self.box_coder = box_coder

    def match_targets_to_anchors(self, anchor, target):
        """Pair every anchor of one image with a ground-truth box.

        Arguments:
            anchor (BoxList): anchors of a single image.
            target (BoxList): ground-truth boxes of the same image.

        Returns:
            BoxList: ``target`` indexed once per anchor, carrying the raw
            matcher output in the ``"matched_idxs"`` field.
        """
        match_quality_matrix = boxlist_iou(target, anchor)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # RPN doesn't need any fields from target
        # for creating the labels, so clear them all
        target = target.copy_with_fields([])
        # get the targets corresponding GT for each anchor
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target[matched_idxs.clamp(min=0)]
        # keep the unclamped indices so callers can still tell real matches
        # from the negative sentinel values
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, anchors, targets):
        """Build per-image classification labels and regression targets.

        Labels are float32: 1 where the matcher returned an index >= 0,
        0 otherwise, and -1 for anchors that are flagged as not visible or
        whose match equals ``Matcher.BETWEEN_THRESHOLDS``.

        Arguments:
            anchors (list[BoxList]): one BoxList of anchors per image; each
                must carry a ``"visibility"`` field.
            targets (list[BoxList]): one BoxList of ground truth per image.

        Returns:
            tuple(list[Tensor], list[Tensor]): labels and regression targets,
            one entry per image, in the order of the inputs.
        """
        labels = []
        regression_targets = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            matched_targets = self.match_targets_to_anchors(
                anchors_per_image, targets_per_image
            )

            matched_idxs = matched_targets.get_field("matched_idxs")
            labels_per_image = matched_idxs >= 0
            labels_per_image = labels_per_image.to(dtype=torch.float32)

            # discard anchors that go out of the boundaries of the image
            labels_per_image[~anchors_per_image.get_field("visibility")] = -1

            # discard indices that are between thresholds
            inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
            labels_per_image[inds_to_discard] = -1

            # compute regression targets
            regression_targets_per_image = self.box_coder.encode(
                matched_targets.bbox, anchors_per_image.bbox
            )

            # Without these appends the method would always return two empty
            # lists and the per-image results computed above would be lost.
            labels.append(labels_per_image)
            regression_targets.append(regression_targets_per_image)

        return labels, regression_targets

    def __call__(self, anchors, objectness, box_regression, targets):
        """Compute the RPN losses for a batch.

        Arguments:
            anchors (list[list[BoxList]]): for every image, the anchors of
                each feature level (they are merged with ``cat_boxlist``).
            objectness (list[Tensor]): one tensor per feature level, unpacked
                as ``(N, A, H, W)``.
            box_regression (list[Tensor]): one tensor per feature level,
                viewable as ``(N, -1, 4, H, W)``.
            targets (list[BoxList]): ground-truth boxes, one per image.

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        labels, regression_targets = self.prepare_targets(anchors, targets)

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        # turn the per-image masks into flat indices over the whole batch
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness_flattened = []
        box_regression_flattened = []
        # for each feature level, permute the outputs to make them be in the
        # same format as the labels. Note that the labels are computed for
        # all feature levels concatenated, so we keep the same representation
        # for the objectness and the box_regression
        for objectness_per_level, box_regression_per_level in zip(
            objectness, box_regression
        ):
            N, _, H, W = objectness_per_level.shape
            objectness_per_level = objectness_per_level.permute(0, 2, 3, 1).reshape(
                N, -1
            )
            box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W)
            box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2)
            box_regression_per_level = box_regression_per_level.reshape(N, -1, 4)
            # collect the flattened levels; the lists are concatenated below
            objectness_flattened.append(objectness_per_level)
            box_regression_flattened.append(box_regression_per_level)

        # concatenate on the first dimension (representing the feature levels), to
        # take into account the way the labels were generated (with all feature maps
        # being concatenated as well)
        objectness = cat(objectness_flattened, dim=1).reshape(-1)
        box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        # The regression loss only looks at positive anchors but is normalised
        # by the total number of sampled anchors.
        # NOTE(review): the input/target arguments and ``size_average=False``
        # were missing from the file as received and are restored here; the
        # explicit division below presumes a summed (not averaged) loss --
        # confirm against the project's ``smooth_l1_loss`` signature.
        box_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1.0 / 9,
            size_average=False,
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss
def make_rpn_loss_evaluator(cfg, box_coder):
    """Build the RPN loss evaluator from a configuration object.

    NOTE(review): the arguments of ``Matcher`` and
    ``BalancedPositiveNegativeSampler`` were missing from the file as
    received. They are restored on the assumption that ``cfg`` exposes
    ``MODEL.RPN.FG_IOU_THRESHOLD``, ``MODEL.RPN.BG_IOU_THRESHOLD``,
    ``MODEL.RPN.BATCH_SIZE_PER_IMAGE`` and ``MODEL.RPN.POSITIVE_FRACTION``,
    and that the matcher accepts ``allow_low_quality_matches`` -- confirm
    against the project's config defaults and the two constructors.

    Arguments:
        cfg: project configuration node (see note above).
        box_coder (BoxCoder): forwarded unchanged to ``RPNLossComputation``.

    Returns:
        RPNLossComputation: evaluator wired with the matcher and sampler
        created here.
    """
    matcher = Matcher(
        cfg.MODEL.RPN.FG_IOU_THRESHOLD,
        cfg.MODEL.RPN.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )

    fg_bg_sampler = BalancedPositiveNegativeSampler(
        cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, cfg.MODEL.RPN.POSITIVE_FRACTION
    )

    loss_evaluator = RPNLossComputation(matcher, fg_bg_sampler, box_coder)
    return loss_evaluator