# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch


class Matcher(object):
    """
    Assign to each predicted "element" (e.g., a box) a ground-truth element.

    Each predicted element will have exactly zero or one matches; each
    ground-truth element may be assigned to zero or more predicted elements.

    Matching is based on the MxN match_quality_matrix, which characterizes
    how well each (ground-truth, predicted)-pair match. For example, if the
    elements are boxes, the matrix may contain box IoU overlap values.

    The matcher returns a tensor of size N containing the index of the
    ground-truth element m that matches to prediction n. If there is no
    match, a negative value is returned (BELOW_LOW_THRESHOLD or
    BETWEEN_THRESHOLDS).
    """

    # Sentinel written for predictions whose best quality is < low_threshold.
    BELOW_LOW_THRESHOLD = -1
    # Sentinel written for predictions whose best quality lies in
    # [low_threshold, high_threshold).
    BETWEEN_THRESHOLDS = -2

    def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional
                matches for predictions that have only low-quality match
                candidates. See set_low_quality_matches_ for more details.

        Raises:
            AssertionError: if low_threshold > high_threshold.
        """
        assert low_threshold <= high_threshold, (
            "low_threshold must not exceed high_threshold"
        )
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing
                the pairwise quality between M ground-truth elements and N
                predicted elements.

        Returns:
            matches (Tensor[int64]): an N tensor where matches[i] is a matched
                gt in [0, M - 1] or a negative value indicating that
                prediction i could not be matched. When M == 0 every entry is
                BELOW_LOW_THRESHOLD.

        Raises:
            ValueError: if the matrix is empty because N == 0 while M > 0.
        """
        if match_quality_matrix.numel() == 0:
            if match_quality_matrix.shape[0] == 0:
                # No ground-truth elements: nothing can be matched, so every
                # prediction is reported as below the low threshold.
                return torch.full(
                    (match_quality_matrix.size(1),),
                    Matcher.BELOW_LOW_THRESHOLD,
                    dtype=torch.int64,
                    device=match_quality_matrix.device,
                )
            raise ValueError("No proposal boxes available for one of the images "
                             "during training")

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            # Keep the raw argmax so low-quality matches can be restored after
            # the thresholding below overwrites them with sentinels.
            all_matches = matches.clone()

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (
            matched_vals < self.high_threshold
        )
        matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low-quality
        matches. Modifies ``matches`` in place.

        Specifically, for each ground-truth find the set of predictions that
        have maximum overlap with it (including ties); for each prediction in
        that set, restore its entry in ``matches`` to the value it had before
        thresholding (``all_matches``), i.e. the ground-truth with which that
        prediction has the highest quality value. Note that this is the
        prediction's own best ground-truth, which is not necessarily the
        ground-truth for which the prediction was selected.

        Args:
            matches (Tensor[int64]): length-N thresholded matches; updated
                in place.
            all_matches (Tensor[int64]): length-N argmax matches taken before
                thresholding.
            match_quality_matrix (Tensor[float]): the MxN quality matrix.
        """
        # For each gt, find the prediction with which it has highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.nonzero(
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        )
        # Example gt_pred_pairs_of_highest_quality:
        #   tensor([[    0, 39796],
        #           [    1, 32055],
        #           [    1, 32070],
        #           [    2, 39190],
        #           [    2, 40255],
        #           [    3, 40390],
        #           [    3, 41455],
        #           [    4, 45470],
        #           [    5, 45325],
        #           [    5, 46390]])
        # Each row is a (gt index, prediction index)
        # Note how gt items 1, 2, 3, and 5 each have two ties

        pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]