|
|
|
|
|
|
|
|
|
from pdb import set_trace as bb |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class NghSampler(nn.Module):
    """ Given dense feature maps and pixel-dense flow,
    compute a subset of all correspondences and return their scores and labels.

    Distance to GT =>  0 ... pos_d ... neg_d ... ngh
    Pixel label    =>  + + + + + + 0 0 - - - - - - -

    Subsample on query side: if > 0, regular grid
                             < 0, random points
    In both cases, the number of query points per image is about
    (W - 2*border) * (H - 2*border) / subq**2.

    Args:
        ngh:      radius (pixels) of the neighbourhood around the GT match in
                  which negatives are taken; with 0 no neighbourhood negatives
                  are produced.
        subq:     query sub-sampling step (see above); must be non-zero.
        subd:     step (>= 1) of the offset lattice around the GT match.
        pos_d:    offsets at distance <= pos_d from the GT match are positives.
        neg_d:    offsets at distance in [neg_d, ngh] are negatives; offsets in
                  between are ignored.
        border:   margin (pixels) excluded when sampling query / distractor
                  points; defaults to ngh when None.
        subd_neg: step for extra "distractor" points sampled over the whole
                  second image (same sign convention as subq); 0 disables them.
    """

    def __init__(self, ngh, subq=-8, subd=1, pos_d=2, neg_d=4, border=16, subd_neg=-8):
        nn.Module.__init__(self)
        assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
        self.ngh = ngh
        self.pos_d = pos_d
        self.neg_d = neg_d
        assert subd >= 1, 'subd must be a positive integer'
        assert subd <= ngh or ngh == 0
        assert subq != 0
        self.sub_q = subq
        self.sub_d = subd
        self.sub_d_neg = subd_neg
        if border is None:
            border = ngh
        assert border >= ngh, 'border has to be larger than ngh'
        self.border = border
        self.precompute_offsets()

    def precompute_offsets(self):
        """Register the (2, P) positive and (2, Q) negative (dx, dy) offset buffers."""
        pos_d2 = self.pos_d ** 2
        neg_d2 = self.neg_d ** 2
        rad2 = self.ngh ** 2
        # Largest multiple of sub_d not exceeding ngh: the offset lattice is then
        # symmetric and always contains the GT location (0, 0).
        rad = (self.ngh // self.sub_d) * self.sub_d
        pos = []
        neg = []
        for j in range(-rad, rad + 1, self.sub_d):
            for i in range(-rad, rad + 1, self.sub_d):
                d2 = i * i + j * j
                if d2 <= pos_d2:
                    pos.append((i, j))
                elif neg_d2 <= d2 <= rad2:
                    neg.append((i, j))

        self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1, 2).t())
        self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1, 2).t())

    def gen_grid(self, step, aflow):
        """Sample pixel positions inside the border of every image of the batch.

        step > 0: regular grid with the given step.
        step < 0: n = (H-2*border)*(W-2*border) // step**2 uniformly random
                  positions, the same positions being used for every image.

        Returns flat index tensors (b, y, x) and the sample-grid shape,
        (B, H1, W1) for a regular grid or (B, n) for random points.
        """
        B, _, H, W = aflow.shape
        dev = aflow.device
        b1 = torch.arange(B, device=dev)
        if step > 0:
            # regular grid
            x1 = torch.arange(self.border, W - self.border, step, device=dev)
            y1 = torch.arange(self.border, H - self.border, step, device=dev)
            H1, W1 = len(y1), len(x1)
            shape = (B, H1, W1)
            x1 = x1[None, None, :].expand(B, H1, W1).reshape(-1)
            y1 = y1[None, :, None].expand(B, H1, W1).reshape(-1)
            b1 = b1[:, None, None].expand(B, H1, W1).reshape(-1)
        else:
            # random positions
            n = (H - 2 * self.border) * (W - 2 * self.border) // step ** 2
            x1 = torch.randint(self.border, W - self.border, (n,), device=dev)
            y1 = torch.randint(self.border, H - self.border, (n,), device=dev)
            x1 = x1[None, :].expand(B, n).reshape(-1)
            y1 = y1[None, :].expand(B, n).reshape(-1)
            b1 = b1[:, None].expand(B, n).reshape(-1)
            shape = (B, n)
        return b1, y1, x1, shape

    def _distractor_scores(self, feat1, feat2, b1, xy2, aflow):
        """Score each query against points sampled over the whole second image.

        Scores of distractors lying closer than neg_d to the query's GT match
        *in the same batch image* are zeroed (they are not true negatives).
        """
        b3, y3, x3, _ = self.gen_grid(self.sub_d_neg, aflow)
        dscores = torch.einsum('nk,ik->ni', feat1, feat2[b3, :, y3, x3])   # (N, M)

        xy3 = torch.stack((x3, y3)).float()
        dist = torch.cdist(xy2.float().t(), xy3.t(),
                           compute_mode='donot_use_mm_for_euclid_dist')
        # Compare batch indices explicitly instead of offsetting coordinates,
        # so points of different images never collide whatever the image size.
        same_img = b1[:, None] == b3[None, :]
        dscores[(dist < self.neg_d) & same_img] = 0
        return dscores

    def forward(self, feats, confs, aflow, **kw):
        """Score sampled query points against positives, negatives and distractors.

        Args:
            feats: pair (feat1, feat2) of dense feature maps indexed as [b, :, y, x].
            confs: optional pair of confidence maps; only confs[0] is sampled.
            aflow: (B, 2, H, W) flow whose channels are the (x, y) position in
                   image 2 of each pixel of image 1.
            **kw:  ignored.

        Returns:
            scores: dot-product scores, columns = positives, neighbourhood
                    negatives, then distractors (when subd_neg != 0).
            gt:     uint8 tensor shaped like scores, 1 on the positive columns.
            mask:   bool tensor shaped like the query grid, True where the
                    rounded GT match lies inside image 2.
            qconf:  confs[0] sampled at the queries (grid-shaped) or None.
        """
        _, two, H, W = aflow.shape
        assert two == 2, f'aflow must have 2 channels (x, y), got {two}'
        feat1, feat2 = feats[0], feats[1]
        conf1 = confs[0] if confs else None

        # query positions in the first image
        b1, y1, x1, shape = self.gen_grid(self.sub_q, aflow)
        feat1 = feat1[b1, :, y1, x1]                                        # (N, C)
        qconf = conf1[b1, :, y1, x1].view(shape) if confs else None

        # GT match in the second image, rounded to the nearest pixel.
        flow = aflow[b1, :, y1, x1]                                         # (N, 2)
        # Casting NaN/inf/huge floats to int is ill-defined: park NaN at -1 and
        # clamp the rest just outside the image so such points end up masked.
        flow = flow.masked_fill(torch.isnan(flow), -1).clamp(-1, max(H, W))
        xy2 = torch.floor(flow + 0.5).long().t()                            # (2, N)
        mask = (0 <= xy2[0]) & (0 <= xy2[1]) & (xy2[0] < W) & (xy2[1] < H)
        mask = mask.view(shape)

        def clamp(xy):
            # in-place clamp of (x, y) coordinates to the image bounds
            xy[0].clamp_(0, W - 1)
            xy[1].clamp_(0, H - 1)
            return xy

        # positives: all positive offsets around the GT match
        xy2p = clamp(xy2[:, None, :] + self.pos_offsets[:, :, None])        # (2, P, N)
        pscores = torch.einsum('nk,ink->ni', feat1, feat2[b1, :, xy2p[1], xy2p[0]])

        # negatives: offsets at distance in [neg_d, ngh] from the GT match
        xy2n = clamp(xy2[:, None, :] + self.neg_offsets[:, :, None])        # (2, Q, N)
        nscores = torch.einsum('nk,ink->ni', feat1, feat2[b1, :, xy2n[1], xy2n[0]])

        parts = [pscores, nscores]
        if self.sub_d_neg:
            parts.append(self._distractor_scores(feat1, feat2, b1, xy2, aflow))
        scores = torch.cat(parts, dim=1)

        gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
        gt[:, :pscores.shape[1]] = 1

        return scores, gt, mask, qconf
|
|