|
|
|
|
|
|
|
|
|
from pdb import set_trace as bb |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .ap_loss import APLoss |
|
from datasets.utils import applyh |
|
|
|
|
|
class PixelAPLoss (nn.Module):
    """ Computes the pixel-wise AP loss:
        Given two images and ground-truth optical flow, computes the AP per pixel.

        feat1:  (B, C, H, W)   pixel-wise features extracted from img1
        feat2:  (B, C, H, W)   pixel-wise features extracted from img2
        aflow:  (B, 2, H, W)   absolute flow: aflow[...,y1,x1] = x2,y2

        The flow is not given directly to `forward`: it is derived from a
        homography with `aflow_from_H`.
    """
    def __init__(self, sampler, nq=20, inner_bw=False, bw_step=256):
        """
        sampler:   callable invoked as sampler((desc1, desc2), reliability, aflow)
                   whose result is unpacked as (scores, gt, msk, qconf).
        nq:        forwarded as first argument to APLoss
                   (presumably the number of quantization bins -- TODO confirm).
        inner_bw:  fallback for `backward_loss` when forward() gets a falsy one.
        bw_step:   number of rows of `scores` handled per backward() call
                   in the chunked mode of forward().
        """
        nn.Module.__init__(self)
        self.aploss = APLoss(nq, min=0, max=1, euc=False)
        self.name = 'pixAP'
        self.sampler = sampler
        self.inner_bw = inner_bw
        self.bw_step = bw_step

    def loss_from_ap(self, ap, rel):
        """ Turns an AP value into a loss. `rel` is ignored here; it is part of
            the signature so that an overriding implementation can use it.
        """
        return 1 - ap

    def forward(self, desc1, desc2, homography, backward_loss=None, **kw):
        """ Computes the AP loss between `desc1` and `desc2`.

        desc1, desc2:   descriptors, handed to the sampler as a pair.
        homography:     passed to aflow_from_H together with desc1.
        backward_loss:  if truthy (or if self.inner_bw is truthy) while the module
                        is training with grad enabled, the loss is computed and
                        back-propagated chunk by chunk; the value is also used
                        as the multiplier of each chunk's loss.
        kw:             only kw.get('reliability') is read, and given to the sampler.

        Always returns a dict with the single key 'ap_loss', whose value is:
          - 0 if desc1 is empty or the sampler returns an empty mask;
          - in chunked mode, a tuple (loss_as_float, [(tensor, grad), ...]) pairing
            the sampler's `scores` (and `qconf` when it requires grad) with the
            gradient accumulated on their detached copies;
          - otherwise the mean of the per-pixel loss over the mask.
        """
        if len(desc1) == 0: return dict(ap_loss=0)
        aflow = aflow_from_H(homography, desc1)
        descriptors = (desc1, desc2)
        scores, gt, msk, qconf = self.sampler(descriptors, kw.get('reliability'), aflow)

        # compute pixel-wise AP
        n = msk.numel()
        # same return type as every other exit path (this used to be a bare 0)
        if n == 0: return dict(ap_loss=0)
        scores, gt = scores.view(n,-1), gt.view(n,-1)

        backward_loss = backward_loss or self.inner_bw
        if self.training and torch.is_grad_enabled() and backward_loss:
            # keep the graph-attached tensors, and work on detached leaves
            # so that each backward() below stops at `scores` / `qconf`
            scores_ = scores
            qconf_ = qconf if qconf is not None else scores.new_ones(msk.shape)
            scores = scores.detach().requires_grad_(True)
            qconf = qconf_.detach().requires_grad_(True)
            msk = msk.ravel()

            # loop-invariant normalization: multiplier / number of selected pixels
            scale = backward_loss / msk.sum()

            loss = 0
            for i in range(0, n, self.bw_step):
                sl = slice(i, i+self.bw_step)
                ap = self.aploss(scores[sl], gt[sl])
                pixel_loss = self.loss_from_ap(ap, qconf.ravel()[sl])
                chunk_loss = scale * pixel_loss[msk[sl]].sum()
                loss += float(chunk_loss)
                chunk_loss.backward() # accumulates into scores.grad (and qconf.grad)
            loss = (loss, [(scores_,scores.grad)])
            if qconf_.requires_grad: loss[1].append((qconf_,qconf.grad))

        else:
            ap = self.aploss(scores, gt).view(msk.shape)
            pixel_loss = self.loss_from_ap(ap, qconf)
            loss = pixel_loss[msk].mean()

        return dict(ap_loss=loss)
|
|
|
|
|
def make_grid(B, H, W, device ):
    """ Builds integer index grids for a batch of B images of size H x W.

    Returns a pair:
      - batch index of every pixel, shape (B, H*W)
      - (x, y) position of every pixel, shape (B, H*W, 2), rows enumerated first
        (x varies fastest).
    """
    full = (B, H, W)

    def index_along(axis):
        # arange over a single axis, broadcast to the full (B, H, W) shape
        shape = [1, 1, 1]
        shape[axis] = full[axis]
        return torch.arange(full[axis], device=device).view(*shape).expand(*full)

    batch_idx, rows, cols = (index_along(axis) for axis in range(3))
    xy = torch.stack((cols, rows), dim=-1)
    return batch_idx.view(B, H*W), xy.view(B, H*W, 2)
|
|
|
|
|
def aflow_from_H( H_1to2, feat1 ):
    """ Builds the absolute flow induced by `H_1to2` on the pixel grid of `feat1`.

    feat1 must be 4-dimensional (B, _, H, W); only its shape and device are used.
    Every (x, y) grid position is converted to float and given to `applyh`
    along with H_1to2 (presumably applying the homography -- see datasets.utils).
    The result is returned with shape (B, 2, H, W).
    """
    batch, _, height, width = feat1.shape
    _, grid_xy = make_grid(batch, height, width, feat1.device)
    warped_xy = applyh(H_1to2, grid_xy.float())
    # (B, H*W, 2) -> (B, H, W, 2) -> channels-first (B, 2, H, W)
    return warped_xy.view(batch, height, width, 2).permute(0, 3, 1, 2)
|
|