import torch
from ..utils.base_model import BaseModel
import numpy as np


# borrow from dedode
def dual_softmax_matcher(
    desc_A: torch.Tensor,
    desc_B: torch.Tensor,
    threshold=0.1,
    inv_temperature=20,
    normalize=True,
):
    """Match two descriptor sets with the dual-softmax mutual-NN criterion.

    Adapted from DeDoDe.

    Args:
        desc_A: descriptors of shape (B, C, N), or (C, N) for a single pair.
        desc_B: descriptors of shape (B, C, M), or (C, M) for a single pair.
        threshold: minimum dual-softmax probability for a match to be kept.
        inv_temperature: factor applied to the similarity before the softmaxes.
        normalize: if True, L2-normalize descriptors along the channel axis.

    Returns:
        matches0: long tensor of shape (B, N); for each descriptor of
            ``desc_A`` the index of its match in ``desc_B``, or -1.
        scores0: tensor of shape (B, N) with the dual-softmax probability of
            each kept match, 0 where unmatched.
    """
    # Promote unbatched (C, N) input to (1, C, N) *before* reading the shape,
    # otherwise the 3-way unpack below fails for 2-D descriptors.
    if desc_A.dim() < 3:
        desc_A, desc_B = desc_A[None], desc_B[None]
    B, _, N = desc_A.shape

    if normalize:
        desc_A = desc_A / desc_A.norm(dim=1, keepdim=True)
        desc_B = desc_B / desc_B.norm(dim=1, keepdim=True)

    sim = (
        torch.einsum("b c n, b c m -> b n m", desc_A, desc_B) * inv_temperature
    )
    # Match probability seen from both sides (column- and row-softmax).
    P = sim.softmax(dim=-2) * sim.softmax(dim=-1)

    matches0 = torch.full((B, N), -1, dtype=torch.long, device=P.device)
    scores0 = torch.zeros((B, N), dtype=P.dtype, device=P.device)
    if P.numel() == 0:
        # One side has no descriptors: max() over an empty dim would raise,
        # and there is nothing to match anyway.
        return matches0, scores0

    # Keep pairs that are the maximum of both their row and their column
    # (mutual nearest neighbours) and exceed the probability threshold.
    keep = (
        (P == P.max(dim=-1, keepdim=True).values)
        & (P == P.max(dim=-2, keepdim=True).values)
        & (P > threshold)
    )
    b_idx, i_idx, j_idx = keep.nonzero(as_tuple=True)
    # Index with the batch coordinate as well, so matches found in one batch
    # element are never written into another.
    matches0[b_idx, i_idx] = j_idx
    scores0[b_idx, i_idx] = P[b_idx, i_idx, j_idx]
    return matches0, scores0


class DualSoftMax(BaseModel):
    """Matcher keeping mutual dual-softmax nearest neighbours above a threshold."""

    default_conf = {
        "match_threshold": 0.2,
        "inv_temperature": 20,
    }
    # shape: B x DIM x M
    required_inputs = ["descriptors0", "descriptors1"]

    def _init(self, conf):
        pass

    def _forward(self, data):
        """Match ``descriptors0`` against ``descriptors1``.

        Returns a dict with ``matches0`` (B x N0, -1 = unmatched) and
        ``matching_scores0`` (B x N0 float scores).
        """
        desc0 = data["descriptors0"]
        desc1 = data["descriptors1"]
        if desc0.size(-1) == 0 or desc1.size(-1) == 0:
            # No keypoints on one side: every keypoint of image 0 is unmatched.
            # The output is (batch, num_keypoints0), not (batch, DIM).
            batch = desc0.shape[0] if desc0.dim() >= 3 else 1
            matches0 = torch.full(
                (batch, desc0.shape[-1]),
                -1,
                dtype=torch.long,
                device=desc0.device,
            )
            return {
                "matches0": matches0,
                # Float scores, consistent with the non-empty path.
                "matching_scores0": torch.zeros(
                    matches0.shape, dtype=desc0.dtype, device=desc0.device
                ),
            }
        matches0, scores0 = dual_softmax_matcher(
            desc0,
            desc1,
            threshold=self.conf["match_threshold"],
            inv_temperature=self.conf["inv_temperature"],
        )
        return {
            "matches0": matches0,  # B x N0
            "matching_scores0": scores0,
        }