Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
from typing import List | |
import torch | |
from detectron2.layers import nonzero_tuple | |
# TODO: the name is too general | |
class Matcher:
    """
    Assign to each predicted "element" (e.g., a box) at most one ground-truth element.

    Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to zero or more predicted elements.

    The matching is determined by the MxN match_quality_matrix, which characterizes
    how well each (ground-truth, prediction) pair matches. For example, if the
    elements are boxes, this matrix may contain box intersection-over-union
    overlap values.

    The matcher returns (a) a vector of length N containing the index of the
    ground-truth element m in [0, M) that matches to prediction n in [0, N), and
    (b) a vector of length N containing the label for each prediction.
    """

    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): thresholds used to stratify predictions into levels.
                Must be sorted in non-decreasing order; the first threshold (if any)
                must be > 0. Any sequence (e.g. a tuple) is accepted and copied.
            labels (list): the value assigned to predictions belonging to each level.
                A label can be one of {-1, 0, 1} signifying
                {ignore, negative class, positive class}, respectively.
                Must contain exactly ``len(thresholds) + 1`` entries. It is copied.
            allow_low_quality_matches (bool): if True, additionally label as positive (1)
                every prediction that is a best match for some ground-truth element,
                even if its match quality falls below the positive threshold.
                See set_low_quality_matches_ for more details.

            For example,
                thresholds = [0.3, 0.5]
                labels = [0, -1, 1]
                All predictions with iou < 0.3 will be marked with 0 and
                thus will be considered as false positives while training.
                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
                thus will be ignored.
                All predictions with 0.5 <= iou will be marked with 1 and
                thus will be considered as true positives.
        """
        # Copy both inputs so that later mutation by the caller cannot affect this
        # matcher, and so that tuples (or other sequences) are accepted.
        thresholds = list(thresholds)
        labels = list(labels)
        # A quality of 0 (the value assumed when there is no ground truth, see
        # __call__) must always fall into the first level, i.e. get labels[0].
        assert len(thresholds) == 0 or thresholds[0] > 0, "the first threshold must be > 0"
        # Pad with -inf / +inf so that consecutive pairs cover (-inf, +inf).
        thresholds.insert(0, -float("inf"))
        thresholds.append(float("inf"))
        # Currently torchscript does not support all + generator
        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]), (
            "thresholds must be sorted in non-decreasing order"
        )
        assert all([label in [-1, 0, 1] for label in labels]), "labels must be -1, 0 or 1"
        assert len(labels) == len(thresholds) - 1, "expected len(labels) == len(thresholds) + 1"
        self.thresholds = thresholds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (this is asserted).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M). When M == 0 it is all zeros, which is not
                a valid index; rely on match_labels in that case.
            match_labels (Tensor[int8]): a vector of length N, where match_labels[i]
                indicates whether a prediction is a true or false positive or ignored.
        """
        assert match_quality_matrix.dim() == 2
        num_preds = match_quality_matrix.size(1)
        if match_quality_matrix.numel() == 0:
            default_matches = match_quality_matrix.new_full((num_preds,), 0, dtype=torch.int64)
            # When no gt boxes exist, we define IOU = 0 and therefore set labels
            # to `self.labels[0]`, which usually defaults to background class 0.
            # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
            default_match_labels = match_quality_matrix.new_full(
                (num_preds,), self.labels[0], dtype=torch.int8
            )
            return default_matches, default_match_labels

        assert torch.all(match_quality_matrix >= 0)

        # match_quality_matrix is M (gt) x N (predicted).
        # Max over gt elements (dim 0) to find the best gt candidate for each prediction.
        matched_vals, matches = match_quality_matrix.max(dim=0)

        # Start from 1; every finite value is relabeled below because the padded
        # thresholds partition [-inf, +inf) into consecutive half-open intervals.
        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
        for (label, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
            in_level = (matched_vals >= low) & (matched_vals < high)
            match_labels.masked_fill_(in_level, label)

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Produce additional positive labels for predictions that have only low-quality matches.

        Specifically, for each ground-truth G find the set of predictions that have
        maximum quality with it (including ties), and set the label of every prediction
        in that set to 1, modifying ``match_labels`` in place. The matched
        ground-truth index of such a prediction is not changed.

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        # For each gt, find the highest quality it reaches with any prediction.
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # A prediction is selected if it attains the highest quality of at least one
        # gt (ties included). Reducing the MxN mask over gts gives an N-long mask.
        # NOTE(review): if a gt row is all zeros, its maximum (0) ties with every
        # prediction, so every prediction is marked positive; presumably callers
        # filter such ground truths (e.g. empty boxes) beforehand — verify.
        is_best_for_some_gt = (
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        ).any(dim=0)
        # If an anchor was labeled positive only due to a low-quality match
        # with gt_A, but it has larger overlap with gt_B, its matched index will still be gt_B.
        # This follows the implementation in Detectron, and is found to have no significant impact.
        match_labels.masked_fill_(is_best_for_some_gt, 1)