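# NOTE(review): the next three lines look like web-page residue, not module
# content; kept as comments so the file parses.
# Spaces:
# Runtime error
# Runtime error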
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
class Matcher(object):
    """
    Assign to each predicted "element" (e.g., a box) a ground-truth element.

    Each predicted element will have exactly zero or one matches; each
    ground-truth element may be assigned to zero or more predicted elements.

    Matching is based on the MxN match_quality_matrix, that characterizes how
    well each (ground-truth, predicted)-pair match. For example, if the
    elements are boxes, the matrix may contain box IoU overlap values.

    The matcher returns a tensor of size N containing the index of the
    ground-truth element m that matches to prediction n. If there is no match,
    a negative value is returned (BELOW_LOW_THRESHOLD or BETWEEN_THRESHOLDS).
    """

    # Sentinel written for predictions whose best quality is < low_threshold.
    BELOW_LOW_THRESHOLD = -1
    # Sentinel written for predictions whose best quality lies in
    # [low_threshold, high_threshold).
    BETWEEN_THRESHOLDS = -2

    def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.

        Raises:
            AssertionError: if low_threshold is greater than high_threshold.
        """
        assert low_threshold <= high_threshold
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements.

        Returns:
            matches (Tensor[int64]): an N tensor where matches[i] is a matched
                gt in [0, M - 1] or a negative value indicating that prediction
                i could not be matched.

                If the matrix has no elements at all (M == 0 or N == 0), an
                empty int64 tensor of shape (0,) on the matrix's device is
                returned instead. Note that this has length 0 even when N > 0.
        """
        if match_quality_matrix.numel() == 0:
            # handle empty case
            device = match_quality_matrix.device
            return torch.empty((0,), dtype=torch.int64, device=device)

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            # Keep the raw per-prediction argmax so that low-quality matches
            # can be restored after thresholding overwrites them below.
            all_matches = matches.clone()

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (
            matched_vals < self.high_threshold
        )
        matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.

        Args:
            matches (Tensor[int64]): the thresholded per-prediction matches;
                modified in place.
            all_matches (Tensor[int64]): the per-prediction best gt index
                before thresholding was applied.
            match_quality_matrix (Tensor[float]): the MxN quality matrix.

        Returns:
            None. The result is written into ``matches``.
        """
        # For each gt, find the prediction with which it has highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.nonzero(
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        )
        # Example gt_pred_pairs_of_highest_quality:
        #   tensor([[    0, 39796],
        #           [    1, 32055],
        #           [    1, 32070],
        #           [    2, 39190],
        #           [    2, 40255],
        #           [    3, 40390],
        #           [    3, 41455],
        #           [    4, 45470],
        #           [    5, 45325],
        #           [    5, 46390]])
        # Each row is a (gt index, prediction index)
        # Note how gt items 1, 2, 3, and 5 each have two ties

        pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
        # Restore the pre-threshold best gt for those predictions. A prediction
        # that was already matched holds the same value, so this is a no-op
        # for it.
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]