File size: 5,027 Bytes
3ef85e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class NghSampler (nn.Module):
    """ Given dense feature maps and pixel-dense flow,
    compute a subset of all correspondences and return their scores and labels.

    Distance to GT =>  0 ... pos_d ... neg_d ... ngh
    Pixel label    =>  + + + + + + 0 0 - - - - - - -

    Subsample on query side: if subq > 0, regular grid with step subq
                             if subq < 0, random points
    In both cases, the number of query points per image is about
    (W - 2*border) * (H - 2*border) / subq**2.

    If subd_neg != 0, feature vectors sampled from a second grid (same
    convention as subq) over the second image are appended as extra negatives
    ("distractors"); distractors closer than neg_d to the GT match *in the
    same image* get their score set to 0.
    """

    def __init__(self, ngh, subq=-8, subd=1, pos_d=2, neg_d=4, border=16, subd_neg=-8):
        """
        ngh:      radius (pixels) of the neighbourhood around the GT match in
                  which negatives are taken. If 0, the upper-bound checks on
                  neg_d / subd are relaxed and no neighbourhood negatives exist.
        subq:     query subsampling step (>0 regular grid, <0 random points).
        subd:     step of the offset grid used inside the neighbourhood.
        pos_d:    offsets at distance <= pos_d from the GT are positives.
        neg_d:    offsets at distance in [neg_d, ngh] are negatives.
        border:   margin (pixels) excluded when sampling queries/distractors;
                  None means ngh. Must be >= ngh.
        subd_neg: distractor grid step (same convention as subq); 0 disables.
        """
        nn.Module.__init__(self)
        assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
        self.ngh = ngh
        self.pos_d = pos_d
        self.neg_d = neg_d
        assert subd <= ngh or ngh == 0
        assert subq != 0
        self.sub_q = subq
        self.sub_d = subd
        self.sub_d_neg = subd_neg
        if border is None:
            border = ngh
        assert border >= ngh, 'border has to be larger than ngh'
        self.border = border
        self.precompute_offsets()

    def precompute_offsets(self):
        """ Precompute the (dx, dy) offsets of positive and negative samples.

        Offsets lie on a grid of step sub_d that always contains (0, 0).
        An offset is positive if its distance to (0, 0) is <= pos_d, negative
        if the distance lies in [neg_d, ngh], and ignored otherwise.
        Registers `pos_offsets` and `neg_offsets` as (2, K) LongTensor buffers
        (row 0 = dx, row 1 = dy).
        """
        pos_d2 = self.pos_d ** 2
        neg_d2 = self.neg_d ** 2
        rad2 = self.ngh ** 2
        # Largest multiple of sub_d not exceeding ngh, so that the grid
        # range(-rad, rad+1, sub_d) is centred on 0 (the GT location).
        rad = (self.ngh // self.sub_d) * self.sub_d
        pos = []
        neg = []
        for j in range(-rad, rad + 1, self.sub_d):
            for i in range(-rad, rad + 1, self.sub_d):
                d2 = i * i + j * j
                if d2 <= pos_d2:
                    pos.append((i, j))
                elif neg_d2 <= d2 <= rad2:
                    neg.append((i, j))

        self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1, 2).t())
        self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1, 2).t())

    def gen_grid(self, step, aflow):
        """ Generate sampling positions inside the border of each image.

        step > 0: regular grid with that step (identical for all images).
        step < 0: (H-2*border)*(W-2*border)//step**2 random points, drawn once
                  and shared by all images of the batch.

        Returns (b, y, x, shape): flat LongTensors holding the batch index,
        row and column of every sample, plus the shape ((B, H1, W1) or (B, n))
        into which per-sample results can be reshaped.
        """
        B, two, H, W = aflow.shape
        dev = aflow.device
        b1 = torch.arange(B, device=dev)
        if step > 0:
            # regular grid
            x1 = torch.arange(self.border, W - self.border, step, device=dev)
            y1 = torch.arange(self.border, H - self.border, step, device=dev)
            H1, W1 = len(y1), len(x1)
            shape = (B, H1, W1)
            x1 = x1[None, None, :].expand(B, H1, W1).reshape(-1)
            y1 = y1[None, :, None].expand(B, H1, W1).reshape(-1)
            b1 = b1[:, None, None].expand(B, H1, W1).reshape(-1)
        else:
            # randomly spread
            n = (H - 2 * self.border) * (W - 2 * self.border) // step ** 2
            x1 = torch.randint(self.border, W - self.border, (n,), device=dev)
            y1 = torch.randint(self.border, H - self.border, (n,), device=dev)
            x1 = x1[None, :].expand(B, n).reshape(-1)
            y1 = y1[None, :].expand(B, n).reshape(-1)
            b1 = b1[:, None].expand(B, n).reshape(-1)
            shape = (B, n)
        return b1, y1, x1, shape

    def forward(self, feats, confs, aflow, **kw):
        """ Sample correspondences and score them.

        feats: pair (feat1, feat2) of dense feature maps, indexed as
               [batch, channel, y, x] (presumably (B, C, H, W) with the same
               H, W as aflow — TODO confirm against callers).
        confs: optional pair of single-channel confidence maps indexed the
               same way, or a falsy value.
        aflow: (B, 2, H, W) absolute flow; aflow[:, 0] is the x and
               aflow[:, 1] the y coordinate of the match in the second image.

        Returns (scores, gt, mask, qconf):
            scores: (N, P + Q [+ M]) dot products of each query feature with
                    its P positives, Q neighbourhood negatives and, if
                    subd_neg != 0, M distractors.
            gt:     uint8 tensor shaped like scores, 1 on the P positive columns.
            mask:   bool tensor (query-grid shape), True where the rounded GT
                    match falls inside the second image.
            qconf:  query confidences reshaped to the query-grid shape, or None.
        """
        B, two, H, W = aflow.shape
        assert two == 2, f"aflow must have 2 channels (x, y), got {two}"
        feat1, conf1 = feats[0], (confs[0] if confs else None)
        feat2 = feats[1]

        # positions in the first image
        b_, y1, x1, shape = self.gen_grid(self.sub_q, aflow)

        # sample features (and optional confidences) from the first image
        feat1 = feat1[b_, :, y1, x1]  # (N, C)
        qconf = conf1[b_, :, y1, x1].view(shape) if confs else None

        # GT positions in the second image, rounded to the nearest pixel
        xy2 = (aflow[b_, :, y1, x1] + 0.5).long().t()  # (2, N)
        mask = (0 <= xy2[0]) & (0 <= xy2[1]) & (xy2[0] < W) & (xy2[1] < H)
        mask = mask.view(shape)

        def clamp(xy):
            # keep sampled coordinates inside the second image (in place)
            xy[0].clamp_(0, W - 1)
            xy[1].clamp_(0, H - 1)
            return xy

        # compute positive scores: (N, P)
        xy2p = clamp(xy2[:, None, :] + self.pos_offsets[:, :, None])
        pscores = torch.einsum('nk,ink->ni', feat1, feat2[b_, :, xy2p[1], xy2p[0]])

        # compute negative scores: (N, Q)
        xy2n = clamp(xy2[:, None, :] + self.neg_offsets[:, :, None])
        nscores = torch.einsum('nk,ink->ni', feat1, feat2[b_, :, xy2n[1], xy2n[0]])

        score_parts = [pscores, nscores]
        if self.sub_d_neg:
            # add distractors from a grid: (N, M)
            b3, y3, x3 = self.gen_grid(self.sub_d_neg, aflow)[:3]
            distractors = feat2[b3, :, y3, x3]
            dscores = torch.einsum('nk,ik->ni', feat1, distractors)
            del distractors

            # remove scores that correspond to positives or nulls, i.e.
            # distractors closer than neg_d to the GT in the *same* image
            # (exact integer squared distances; other images never match)
            same_img = b_[:, None] == b3[None, :]
            dis2 = (x3[None, :] - xy2[0][:, None]) ** 2 \
                 + (y3[None, :] - xy2[1][:, None]) ** 2
            dscores[same_img & (dis2 < self.neg_d ** 2)] = 0
            score_parts.append(dscores)

        scores = torch.cat(score_parts, dim=1)
        gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
        gt[:, :pscores.shape[1]] = 1

        return scores, gt, mask, qconf
|