""" Nearest neighbor matcher for normalized descriptors. Optionally apply the mutual check and threshold the distance or ratio. """ import logging import torch import torch.nn.functional as F from ..base_model import BaseModel from ..utils.metrics import matcher_metrics @torch.no_grad() def find_nn(sim, ratio_thresh, distance_thresh): sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True) dist_nn = 2 * (1 - sim_nn) mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device) if ratio_thresh: mask = mask & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1]) if distance_thresh: mask = mask & (dist_nn[..., 0] <= distance_thresh**2) matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1)) return matches def mutual_check(m0, m1): inds0 = torch.arange(m0.shape[-1], device=m0.device) inds1 = torch.arange(m1.shape[-1], device=m1.device) loop0 = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0))) loop1 = torch.gather(m0, -1, torch.where(m1 > -1, m1, m1.new_tensor(0))) m0_new = torch.where((m0 > -1) & (inds0 == loop0), m0, m0.new_tensor(-1)) m1_new = torch.where((m1 > -1) & (inds1 == loop1), m1, m1.new_tensor(-1)) return m0_new, m1_new class NearestNeighborMatcher(BaseModel): default_conf = { "ratio_thresh": None, "distance_thresh": None, "mutual_check": True, "loss": None, } required_data_keys = ["descriptors0", "descriptors1"] def _init(self, conf): if conf.loss == "N_pair": temperature = torch.nn.Parameter(torch.tensor(1.0)) self.register_parameter("temperature", temperature) def _forward(self, data): sim = torch.einsum("bnd,bmd->bnm", data["descriptors0"], data["descriptors1"]) matches0 = find_nn(sim, self.conf.ratio_thresh, self.conf.distance_thresh) matches1 = find_nn( sim.transpose(1, 2), self.conf.ratio_thresh, self.conf.distance_thresh ) if self.conf.mutual_check: matches0, matches1 = mutual_check(matches0, matches1) b, m, n = sim.shape la = sim.new_zeros(b, m + 1, n + 1) la[:, :-1, :-1] = F.log_softmax(sim, -1) + F.log_softmax(sim, -2) mscores0 = (matches0 > -1).float() mscores1 = (matches1 > -1).float() return { "matches0": matches0, "matches1": matches1, "matching_scores0": mscores0, "matching_scores1": mscores1, "similarity": sim, "log_assignment": la, } def loss(self, pred, data): losses = {} if self.conf.loss == "N_pair": sim = pred["similarity"] if torch.any(sim > (1.0 + 1e-6)): logging.warning(f"Similarity larger than 1, max={sim.max()}") scores = torch.sqrt(torch.clamp(2 * (1 - sim), min=1e-6)) scores = self.temperature * (2 - scores) assert not torch.any(torch.isnan(scores)), torch.any(torch.isnan(sim)) prob0 = torch.nn.functional.log_softmax(scores, 2) prob1 = torch.nn.functional.log_softmax(scores, 1) assignment = data["gt_assignment"].float() num = torch.max(assignment.sum((1, 2)), assignment.new_tensor(1)) nll0 = (prob0 * assignment).sum((1, 2)) / num nll1 = (prob1 * assignment).sum((1, 2)) / num nll = -(nll0 + nll1) / 2 losses["n_pair_nll"] = losses["total"] = nll losses["num_matchable"] = num losses["n_pair_temperature"] = self.temperature[None] else: raise NotImplementedError metrics = {} if self.training else matcher_metrics(pred, data) return losses, metrics