zdou0830's picture
desco
749745d
raw
history blame
No virus
7.38 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F
from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import BalancedPositiveNegativeSampler
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd
class FastRCNNLossComputation(object):
    """
    Loss computation for the Faster R-CNN box head (FPN is supported too).

    Typical usage is two-phase: ``subsample`` is called first to label and
    sample the proposals (the result is cached on the instance), then the
    instance itself is called with the head outputs to obtain the losses.
    """

    def __init__(self, proposal_matcher, fg_bg_sampler, box_coder):
        """
        Arguments:
            proposal_matcher (Matcher)
            fg_bg_sampler (BalancedPositiveNegativeSampler)
            box_coder (BoxCoder)
        """
        self.proposal_matcher = proposal_matcher
        self.fg_bg_sampler = fg_bg_sampler
        self.box_coder = box_coder

    def match_targets_to_proposals(self, proposal, target):
        """
        Assign one ground-truth entry to every proposal of a single image.

        Arguments:
            proposal (BoxList): proposals of one image
            target (BoxList): ground-truth boxes of the same image

        Returns:
            BoxList carrying a "labels" field and a "matched_idxs" field
            (the raw output of the proposal matcher).
        """
        iou_matrix = boxlist_iou(target, proposal)
        matched_idxs = self.proposal_matcher(iou_matrix)
        # Only the "labels" field is needed to select the targets.
        gt_with_labels = target.copy_with_fields("labels")
        if len(gt_with_labels):
            # The matcher may emit negative sentinel indices (e.g. -2); clamp
            # them to 0 so the lookup stays in bounds even when the image
            # holds a single ground-truth box.
            matched_targets = gt_with_labels[matched_idxs.clamp(min=0)]
        else:
            # No ground truth in this image: keep the empty box list and give
            # every proposal an all-zero label.
            gt_label_field = gt_with_labels.get_field("labels")
            zero_labels = torch.zeros_like(
                matched_idxs,
                dtype=gt_label_field.dtype,
                device=gt_label_field.device,
            )
            matched_targets = gt_with_labels
            matched_targets.add_field("labels", zero_labels)
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, proposals, targets):
        """
        Build per-image classification labels and box-regression targets.

        Arguments:
            proposals (list[BoxList])
            targets (list[BoxList])

        Returns:
            (labels, regression_targets): two lists with one tensor per image.
        """
        all_labels = []
        all_regression_targets = []
        for image_proposals, image_targets in zip(proposals, targets):
            matched = self.match_targets_to_proposals(image_proposals, image_targets)
            match_idx = matched.get_field("matched_idxs")

            image_labels = matched.get_field("labels").to(dtype=torch.int64)
            # Proposals below the low threshold are background.
            image_labels[match_idx == Matcher.BELOW_LOW_THRESHOLD] = 0
            # Proposals between the two thresholds are marked -1, which is
            # ignored by the sampler.
            image_labels[match_idx == Matcher.BETWEEN_THRESHOLDS] = -1

            if matched.bbox.shape[0]:
                image_regression_targets = self.box_coder.encode(
                    matched.bbox, image_proposals.bbox
                )
            else:
                # No matched boxes: emit four zero columns, one row per label.
                zero_column = torch.zeros_like(image_labels, dtype=torch.float32)
                image_regression_targets = torch.stack([zero_column] * 4, dim=1)

            all_labels.append(image_labels)
            all_regression_targets.append(image_regression_targets)
        return all_labels, all_regression_targets

    def subsample(self, proposals, targets):
        """
        Label the proposals, run the positive/negative sampler and return
        only the sampled proposals of each image.

        Note: this method is stateful -- the sampled proposals are cached on
        ``self._proposals`` and read back by ``__call__``.

        Arguments:
            proposals (list[BoxList])
            targets (list[BoxList])
        """
        labels, regression_targets = self.prepare_targets(proposals, targets)
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)

        proposals = list(proposals)
        # Attach the labels and regression targets to their box lists.
        for image_labels, image_reg_targets, image_proposals in zip(
            labels, regression_targets, proposals
        ):
            image_proposals.add_field("labels", image_labels)
            image_proposals.add_field("regression_targets", image_reg_targets)

        # Reduce every image to the proposals picked by the sampler, i.e. the
        # union of its positive and negative masks.
        for img_idx, (pos_mask, neg_mask) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
            keep = torch.nonzero(pos_mask | neg_mask).squeeze(1)
            proposals[img_idx] = proposals[img_idx][keep]

        self._proposals = proposals
        return proposals

    @custom_fwd(cast_inputs=torch.float32)
    def __call__(self, class_logits, box_regression):
        """
        Compute the classification and box-regression losses.

        ``subsample`` must have been called beforehand.

        Arguments:
            class_logits (list[Tensor])
            box_regression (list[Tensor])

        Returns:
            classification_loss (Tensor)
            box_loss (Tensor)

        Raises:
            RuntimeError: if ``subsample`` has not been called on this instance.
        """
        class_logits = cat(class_logits, dim=0)
        box_regression = cat(box_regression, dim=0)
        device = class_logits.device

        if not hasattr(self, "_proposals"):
            raise RuntimeError("subsample needs to be called before")
        sampled_proposals = self._proposals

        labels = cat([p.get_field("labels") for p in sampled_proposals], dim=0)
        regression_targets = cat(
            [p.get_field("regression_targets") for p in sampled_proposals], dim=0
        )

        classification_loss = F.cross_entropy(class_logits, labels)

        # Only rows with a positive label contribute to the box loss. For
        # such a row the regression columns of its class are
        # 4 * label + {0, 1, 2, 3}; they are gathered with advanced indexing.
        pos_rows = torch.nonzero(labels > 0).squeeze(1)
        pos_labels = labels[pos_rows]
        coord_offsets = torch.tensor([0, 1, 2, 3], device=device)
        column_index = 4 * pos_labels[:, None] + coord_offsets

        predicted_deltas = box_regression[pos_rows[:, None], column_index]
        target_deltas = regression_targets[pos_rows]
        box_loss = smooth_l1_loss(
            predicted_deltas,
            target_deltas,
            size_average=False,
            beta=1,
        )
        # Normalise by the number of sampled proposals, not just positives.
        box_loss = box_loss / labels.numel()

        return classification_loss, box_loss
def make_roi_box_loss_evaluator(cfg):
    """
    Build the ROI box-head loss evaluator from the ``cfg.MODEL.ROI_HEADS``
    section of the config.

    Arguments:
        cfg: config object exposing ``MODEL.ROI_HEADS`` with the keys
            FG_IOU_THRESHOLD, BG_IOU_THRESHOLD, BBOX_REG_WEIGHTS,
            BATCH_SIZE_PER_IMAGE and POSITIVE_FRACTION.

    Returns:
        FastRCNNLossComputation
    """
    roi_heads_cfg = cfg.MODEL.ROI_HEADS

    # Low-quality matches are explicitly disabled for the box head.
    proposal_matcher = Matcher(
        roi_heads_cfg.FG_IOU_THRESHOLD,
        roi_heads_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    coder = BoxCoder(weights=roi_heads_cfg.BBOX_REG_WEIGHTS)
    sampler = BalancedPositiveNegativeSampler(
        roi_heads_cfg.BATCH_SIZE_PER_IMAGE,
        roi_heads_cfg.POSITIVE_FRACTION,
    )
    return FastRCNNLossComputation(proposal_matcher, sampler, coder)