# PUMP / core/losses/ap_loss_sampler.py
# Philippe Weinzaepfel -- "huggingface demo" (revision 3ef85e9, 5.03 kB)
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class NghSampler (nn.Module):
""" Given dense feature maps and pixel-dense flow,
compute a subset of all correspondences and return their scores and labels.
Distance to GT => 0 ... pos_d ... neg_d ... ngh
Pixel label => + + + + + + 0 0 - - - - - - -
Subsample on query side: if > 0, regular grid
< 0, random points
In both cases, the number of query points is = W*H/subq**2
"""
def __init__(self, ngh, subq=-8, subd=1, pos_d=2, neg_d=4, border=16, subd_neg=-8):
nn.Module.__init__(self)
assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
self.ngh = ngh
self.pos_d = pos_d
self.neg_d = neg_d
assert subd <= ngh or ngh == 0
assert subq != 0
self.sub_q = subq
self.sub_d = subd
self.sub_d_neg = subd_neg
if border is None: border = ngh
assert border >= ngh, 'border has to be larger than ngh'
self.border = border
self.precompute_offsets()
def precompute_offsets(self):
pos_d2 = self.pos_d**2
neg_d2 = self.neg_d**2
rad2 = self.ngh**2
rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple
pos = []
neg = []
for j in range(-rad, rad+1, self.sub_d):
for i in range(-rad, rad+1, self.sub_d):
d2 = i*i + j*j
if d2 <= pos_d2:
pos.append( (i,j) )
elif neg_d2 <= d2 <= rad2:
neg.append( (i,j) )
self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t())
self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t())
def gen_grid(self, step, aflow):
B, two, H, W = aflow.shape
dev = aflow.device
b1 = torch.arange(B, device=dev)
if step > 0:
# regular grid
x1 = torch.arange(self.border, W-self.border, step, device=dev)
y1 = torch.arange(self.border, H-self.border, step, device=dev)
H1, W1 = len(y1), len(x1)
shape = (B, H1, W1)
x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1)
y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1)
b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1)
else:
# randomly spread
n = (H - 2*self.border) * (W - 2*self.border) // step**2
x1 = torch.randint(self.border, W-self.border, (n,), device=dev)
y1 = torch.randint(self.border, H-self.border, (n,), device=dev)
x1 = x1[None,:].expand(B,n).reshape(-1)
y1 = y1[None,:].expand(B,n).reshape(-1)
b1 = b1[:,None].expand(B,n).reshape(-1)
shape = (B, n)
return b1, y1, x1, shape
def forward(self, feats, confs, aflow, **kw):
B, two, H, W = aflow.shape
assert two == 2, bb()
feat1, conf1 = feats[0], (confs[0] if confs else None)
feat2, conf2 = feats[1], (confs[1] if confs else None)
# positions in the first image
b_, y1, x1, shape = self.gen_grid(self.sub_q, aflow)
# sample features from first image
feat1 = feat1[b_, :, y1, x1]
qconf = conf1[b_, :, y1, x1].view(shape) if confs else None
#sample GT from second image
xy2 = (aflow[b_, :, y1, x1] + 0.5).long().t()
mask = (0 <= xy2[0]) * (0 <= xy2[1]) * (xy2[0] < W) * (xy2[1] < H)
mask = mask.view(shape)
def clamp(xy):
torch.clamp(xy[0], 0, W-1, out=xy[0])
torch.clamp(xy[1], 0, H-1, out=xy[1])
return xy
# compute positive scores
xy2p = clamp(xy2[:,None,:] + self.pos_offsets[:,:,None])
pscores = torch.einsum('nk,ink->ni', feat1, feat2[b_, :, xy2p[1], xy2p[0]])
# compute negative scores
xy2n = clamp(xy2[:,None,:] + self.neg_offsets[:,:,None])
nscores = torch.einsum('nk,ink->ni', feat1, feat2[b_, :, xy2n[1], xy2n[0]])
if self.sub_d_neg:
# add distractors from a grid
b3, y3, x3 = self.gen_grid(self.sub_d_neg, aflow)[:3]
distractors = feat2[b3, :, y3, x3]
dscores = torch.einsum('nk,ik->ni', feat1, distractors)
del distractors
# remove scores that corresponds to positives or nulls
x2, y2 = xy2 = xy2.float()
xy3 = torch.stack((x3,y3)).float()
dis2 = torch.cdist((xy2+b_*512).T, (xy3+b3*512).T, compute_mode='donot_use_mm_for_euclid_dist')
dscores[dis2 < self.neg_d] = 0
scores = torch.cat((pscores, nscores, dscores), dim=1)
gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
gt[:, :pscores.shape[1]] = 1
return scores, gt, mask, qconf