|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from kornia.geometry.subpix import dsnt |
|
from kornia.utils.grid import create_meshgrid |
|
|
|
|
|
class FineMatching(nn.Module):
    """FineMatching with s2d paradigm.

    For every coarse match, the centre feature of the window in image 0 is
    correlated against all features of the corresponding window in image 1.
    The expected position under the resulting softmax heatmap is used to
    refine the coarse keypoint in image 1; the keypoint in image 0 is left
    at its coarse location.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feat_f0, feat_f1, data):
        """Refine the coarse matches stored in ``data`` (updated in place).

        Args:
            feat_f0 (torch.Tensor): [M, WW, C]
            feat_f1 (torch.Tensor): [M, WW, C]
            data (dict): reads 'hw0_i', 'hw0_f', 'mkpts0_c', 'mkpts1_c',
                'mconf' and, when 'scale1' is present, 'scale1' and 'b_ids'.
        Update:
            data (dict):{
                'expec_f' (torch.Tensor): [M, 3],
                'mkpts0_f' (torch.Tensor): [M, 2],
                'mkpts1_f' (torch.Tensor): [M, 2],
                'descriptors0' (torch.Tensor): [M, C],
                'descriptors1' (torch.Tensor): [M, C]}

        Returns:
            None. All results are written into ``data``.
        """
        M, WW, C = feat_f0.shape
        W = int(math.sqrt(WW))
        scale = data["hw0_i"][0] / data["hw0_f"][0]
        # Cached on the module because get_fine_match reads W and scale.
        self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale

        # Corner case: no coarse matches were found.
        if M == 0:
            assert (
                not self.training
            ), "M is always >0, when training, see coarse_matching.py"

            data.update(
                {
                    "expec_f": torch.empty(0, 3, device=feat_f0.device),
                    "mkpts0_f": data["mkpts0_c"],
                    "mkpts1_f": data["mkpts1_c"],
                    # Keep the key set identical to the non-empty path so
                    # consumers of the descriptors do not hit a KeyError.
                    "descriptors0": feat_f0.new_empty((0, C)),
                    "descriptors1": feat_f1.new_empty((0, C)),
                }
            )
            return

        # Feature at the centre of each window in image 0: [M, C].
        feat_f0_picked = feat_f0[:, WW // 2, :]

        # Correlate the centre feature with every position of window 1.
        sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1)
        softmax_temp = 1.0 / C**0.5
        weights = torch.softmax(softmax_temp * sim_matrix, dim=1)  # [M, WW]

        # Heatmap-weighted average of the window-1 features: [M, C].
        feat_f1_picked = (feat_f1 * weights.unsqueeze(-1)).sum(dim=1)
        heatmap = weights.view(-1, W, W)

        # Expected position under the heatmap; the `True` flag asks kornia
        # for normalized coordinates.
        coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[
            0
        ]
        grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
            1, -1, 2
        )

        # Per-axis variance as E[g^2] - E[g]^2, then the sum of the two
        # per-axis standard deviations. The clamp keeps sqrt away from
        # zero/negative values caused by round-off.
        var = (
            torch.sum(grid_normalized**2 * weights.unsqueeze(-1), dim=1)
            - coords1_normalized**2
        )
        std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)

        data.update(
            {
                "expec_f": torch.cat([coords1_normalized, std.unsqueeze(1)], -1),
                "descriptors0": feat_f0_picked.detach(),
                "descriptors1": feat_f1_picked.detach(),
            }
        )

        # Convert the normalized offsets into absolute keypoint coordinates.
        self.get_fine_match(coords1_normalized, data)

    @torch.no_grad()
    def get_fine_match(self, coords1_normed, data):
        """Write the refined keypoints 'mkpts0_f' / 'mkpts1_f' into ``data``.

        Uses ``self.W`` and ``self.scale`` as set by the latest ``forward``
        call.

        Args:
            coords1_normed (torch.Tensor): normalized offsets inside the
                window of image 1, one row per match.
            data (dict): reads 'mkpts0_c', 'mkpts1_c', 'mconf' and, when
                'scale1' is present, 'scale1' and 'b_ids'.
        """
        W, scale = self.W, self.scale

        # Image-0 keypoints are not refined.
        mkpts0_f = data["mkpts0_c"]

        # Apply the per-sample scale of image 1 when one is provided.
        scale1 = scale * data["scale1"][data["b_ids"]] if "scale1" in data else scale

        # Normalized offset -> pixels: half window size times the scale.
        # Only the first len(data['mconf']) rows are kept.
        # NOTE(review): presumably this drops extra training-only entries;
        # confirm against the coarse matching stage.
        offsets = coords1_normed * (W // 2) * scale1
        mkpts1_f = data["mkpts1_c"] + offsets[: len(data["mconf"])]

        data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f})
|
|