Realcat
add: GIM (https://github.com/xuelunshen/gim)
4d4dd90
raw
history blame
No virus
3.81 kB
"""
Nearest neighbor matcher for normalized descriptors.
Optionally apply the mutual check and threshold the distance or ratio.
"""
import logging
import torch
import torch.nn.functional as F
from ..base_model import BaseModel
from ..utils.metrics import matcher_metrics
@torch.no_grad()
def find_nn(sim, ratio_thresh, distance_thresh):
    """Find, for each row of ``sim``, the index of its most similar column.

    ``sim`` is expected to hold dot products of normalized descriptors, so
    ``2 * (1 - sim)`` is their squared Euclidean distance. That identity only
    holds for unit-norm vectors; it is assumed here, not checked.

    Args:
        sim: Similarity tensor of shape ``(..., N, M)``. Matching is done
            along the last dimension; larger values mean more similar.
        ratio_thresh: If truthy, apply the ratio test. A match is kept only if
            its distance is at most ``ratio_thresh`` times the distance to the
            second-best candidate. When only one candidate exists (``M == 1``)
            there is no second neighbor, so the test is skipped and passes.
        distance_thresh: If truthy, keep a match only if its Euclidean
            distance is at most ``distance_thresh``.

    Returns:
        Integer tensor of shape ``(..., N)`` with the matched column index per
        row, or -1 where the row is unmatched (always -1 when ``M == 0``).
    """
    num_candidates = sim.shape[-1]
    if num_candidates == 0:
        # topk cannot select from an empty dimension; nothing can match.
        return torch.full(
            sim.shape[:-1], -1, dtype=torch.long, device=sim.device
        )

    # The ratio test needs the two best candidates, but never request more
    # candidates than actually exist (topk would raise).
    k = min(2 if ratio_thresh else 1, num_candidates)
    sim_nn, ind_nn = sim.topk(k, dim=-1, largest=True)
    # Squared distances; the thresholds are squared accordingly below.
    dist_nn = 2 * (1 - sim_nn)

    mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
    if ratio_thresh and k == 2:
        mask = mask & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1])
    if distance_thresh:
        mask = mask & (dist_nn[..., 0] <= distance_thresh**2)
    return torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1))
def mutual_check(m0, m1):
    """Keep only matches that point back to each other (mutual matches).

    Args:
        m0: Integer tensor of shape ``(..., N)``. For each element of set 0 it
            holds the index of its match in set 1, or -1 if unmatched.
        m1: Integer tensor of shape ``(..., M)`` with the same leading dims,
            holding matches from set 1 to set 0 (or -1).

    Returns:
        ``(m0_new, m1_new)`` with the input shapes. Every entry whose partner
        does not map back to it is replaced with -1.
    """
    if m0.shape[-1] == 0 or m1.shape[-1] == 0:
        # One side is empty, so no mutual match can exist. Gathering the
        # placeholder index 0 from an empty dimension would also fail.
        return torch.full_like(m0, -1), torch.full_like(m1, -1)

    inds0 = torch.arange(m0.shape[-1], device=m0.device)
    inds1 = torch.arange(m1.shape[-1], device=m1.device)
    # Follow each match to the other side and back. Unmatched (-1) entries are
    # temporarily replaced by 0 so gather stays in bounds; they are masked
    # out again by the ``> -1`` condition below.
    loop0 = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0)))
    loop1 = torch.gather(m0, -1, torch.where(m1 > -1, m1, m1.new_tensor(0)))
    m0_new = torch.where((m0 > -1) & (inds0 == loop0), m0, m0.new_tensor(-1))
    m1_new = torch.where((m1 > -1) & (inds1 == loop1), m1, m1.new_tensor(-1))
    return m0_new, m1_new
class NearestNeighborMatcher(BaseModel):
    """Match two descriptor sets by nearest-neighbor search on dot products.

    Configuration:
        ratio_thresh: optional ratio-test threshold passed to ``find_nn``.
        distance_thresh: optional distance threshold passed to ``find_nn``.
        mutual_check: if True, keep only mutual nearest neighbors.
        loss: ``"N_pair"`` enables a learnable temperature and the N-pair
            loss; any other value makes ``loss()`` raise.
    """

    default_conf = {
        "ratio_thresh": None,
        "distance_thresh": None,
        "mutual_check": True,
        "loss": None,
    }
    required_data_keys = ["descriptors0", "descriptors1"]

    def _init(self, conf):
        # The only learnable parameter is the temperature of the N-pair loss.
        if conf.loss == "N_pair":
            temperature = torch.nn.Parameter(torch.tensor(1.0))
            self.register_parameter("temperature", temperature)

    def _forward(self, data):
        """Match ``data["descriptors0"]`` against ``data["descriptors1"]``.

        The einsum below fixes the inputs to shapes ``(B, N, D)`` and
        ``(B, M, D)``.

        Returns a dict with:
            matches0 / matches1: ``(B, N)`` / ``(B, M)`` index of the match in
                the other set, or -1 if unmatched.
            matching_scores0 / matching_scores1: 1.0 where matched, else 0.0.
            similarity: ``(B, N, M)`` raw dot-product similarities.
            log_assignment: ``(B, N + 1, M + 1)``. The top-left ``N x M`` block
                is the sum of row-wise and column-wise log-softmax of the
                similarities. The extra last row and column stay at 0.
        """
        sim = torch.einsum(
            "bnd,bmd->bnm", data["descriptors0"], data["descriptors1"]
        )
        matches0 = find_nn(sim, self.conf.ratio_thresh, self.conf.distance_thresh)
        matches1 = find_nn(
            sim.transpose(1, 2), self.conf.ratio_thresh, self.conf.distance_thresh
        )
        if self.conf.mutual_check:
            matches0, matches1 = mutual_check(matches0, matches1)

        b, n, m = sim.shape
        la = sim.new_zeros(b, n + 1, m + 1)
        la[:, :-1, :-1] = F.log_softmax(sim, -1) + F.log_softmax(sim, -2)
        # Nearest-neighbor matching has no confidence: scores are binary.
        mscores0 = (matches0 > -1).float()
        mscores1 = (matches1 > -1).float()
        return {
            "matches0": matches0,
            "matches1": matches1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "similarity": sim,
            "log_assignment": la,
        }

    def loss(self, pred, data):
        """Compute the training loss and, outside training, matching metrics.

        Only the ``"N_pair"`` loss is implemented. Similarities are turned into
        scores ``temperature * (2 - distance)``, where ``distance`` is
        ``sqrt(2 * (1 - sim))`` clamped away from 0. The loss is the negative
        log-likelihood of ``data["gt_assignment"]`` under row-wise and
        column-wise softmax, averaged over both directions. It is normalized
        by the number of ground-truth matches, with a floor of 1.
        ``gt_assignment`` is multiplied elementwise with the ``(B, N, M)``
        log-probabilities; presumably a 0/1 matrix (TODO confirm with the
        data pipeline).

        Raises:
            NotImplementedError: if ``conf.loss`` is not ``"N_pair"``.
        """
        if self.conf.loss != "N_pair":
            raise NotImplementedError(f"Unsupported loss: {self.conf.loss!r}")

        losses = {}
        sim = pred["similarity"]
        if torch.any(sim > (1.0 + 1e-6)):
            # Similarities above 1 mean the descriptors were not normalized.
            logging.warning("Similarity larger than 1, max=%s", sim.max())
        # Clamp keeps sqrt differentiable and guards against tiny negatives.
        scores = torch.sqrt(torch.clamp(2 * (1 - sim), min=1e-6))
        scores = self.temperature * (2 - scores)
        assert not torch.any(torch.isnan(scores)), (
            "NaN in N-pair scores "
            f"(NaN already present in similarity: {bool(torch.any(torch.isnan(sim)))})"
        )
        prob0 = F.log_softmax(scores, 2)
        prob1 = F.log_softmax(scores, 1)

        assignment = data["gt_assignment"].float()
        num = torch.max(assignment.sum((1, 2)), assignment.new_tensor(1))
        nll0 = (prob0 * assignment).sum((1, 2)) / num
        nll1 = (prob1 * assignment).sum((1, 2)) / num
        nll = -(nll0 + nll1) / 2
        losses["n_pair_nll"] = losses["total"] = nll
        losses["num_matchable"] = num
        losses["n_pair_temperature"] = self.temperature[None]

        metrics = {} if self.training else matcher_metrics(pred, data)
        return losses, metrics