# PUMP/core/losses/pixel_ap_loss.py
# Philippe Weinzaepfel
# huggingface demo
# 3ef85e9
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import torch
import torch.nn as nn
import torch.nn.functional as F
from .ap_loss import APLoss
from datasets.utils import applyh
class PixelAPLoss(nn.Module):
    """Pixel-wise average-precision (AP) loss between two descriptor maps.

    The ground-truth correspondences come from a homography: it is turned
    into an absolute flow with ``aflow_from_H`` and passed, together with
    both descriptor maps, to ``sampler``. The sampler returns the scores and
    labels on which the AP is evaluated, plus a mask of valid pixels.

    Tensors involved (layouts as stated by the original authors):
        desc1: (B, C, H, W) pixel-wise features extracted from img1
        desc2: (B, C, H, W) pixel-wise features extracted from img2
        aflow: (B, 2, H, W) absolute flow: aflow[..., y1, x1] = x2, y2
    """

    def __init__(self, sampler, nq=20, inner_bw=False, bw_step=256):
        """Store the sampler and build the underlying AP criterion.

        Args:
            sampler: callable invoked as
                ``sampler((desc1, desc2), reliability, aflow)`` and returning
                ``(scores, gt, msk, qconf)``.
            nq: first argument forwarded to ``APLoss``
                (presumably the number of quantization bins -- TODO confirm).
            inner_bw: fallback for ``forward``'s ``backward_loss`` argument
                when that argument is falsy.
            bw_step: number of rows of ``scores`` handled per chunk in the
                progressive-backward mode.
        """
        super().__init__()
        self.aploss = APLoss(nq, min=0, max=1, euc=False)
        self.name = 'pixAP'
        self.sampler = sampler
        self.inner_bw = inner_bw
        self.bw_step = bw_step

    def loss_from_ap(self, ap, rel):
        """Convert AP values into a loss; ``rel`` is ignored here."""
        return 1 - ap

    def forward(self, desc1, desc2, homography, backward_loss=None, **kw):
        """Compute the pixel-wise AP loss.

        Args:
            desc1, desc2: descriptor maps of the two images.
            homography: mapping from image 1 to image 2, converted to an
                absolute flow by ``aflow_from_H``.
            backward_loss: when truthy (or when ``self.inner_bw`` is), and the
                module is in training mode with gradients enabled, the loss
                is computed and back-propagated chunk by chunk. Its value
                also scales each chunk loss (``True`` acts as 1).
            **kw: only ``kw.get('reliability')`` is read; it is handed to the
                sampler.

        Returns:
            A dict with the single key ``'ap_loss'`` whose value is:
              * ``0`` when ``desc1`` is empty or the sampler's mask is empty;
              * in the regular mode, the mean of the pixel loss over the
                masked pixels;
              * in the progressive mode, a tuple ``(value, grads)`` where
                ``value`` is a Python float and ``grads`` is a list of
                ``(tensor, gradient)`` pairs for the scores and, if it
                requires grad, the sampler's confidence tensor.
        """
        if len(desc1) == 0:
            return dict(ap_loss=0)

        aflow = aflow_from_H(homography, desc1)
        scores, gt, msk, qconf = self.sampler(
            (desc1, desc2), kw.get('reliability'), aflow)

        # compute pixel-wise AP
        n = msk.numel()
        if n == 0:
            # Same return type as every other exit path (was a bare ``0``).
            return dict(ap_loss=0)
        scores, gt = scores.view(n, -1), gt.view(n, -1)

        backward_loss = backward_loss or self.inner_bw
        if self.training and torch.is_grad_enabled() and backward_loss:
            # Progressive loss computation and backward: low memory but slow.
            # Keep handles on the graph-attached tensors, then work on
            # detached leaf copies so each chunk's graph is independent.
            scores_ = scores
            qconf_ = qconf if qconf is not None else scores.new_ones(msk.shape)
            scores = scores.detach().requires_grad_(True)
            qconf = qconf_.detach().requires_grad_(True)
            msk = msk.ravel()

            # Loop-invariant normalization: scale / number of valid pixels.
            weight = backward_loss / msk.sum()

            loss = 0
            for start in range(0, n, self.bw_step):
                sl = slice(start, start + self.bw_step)
                ap = self.aploss(scores[sl], gt[sl])
                pixel_loss = self.loss_from_ap(ap, qconf.ravel()[sl])
                chunk_loss = weight * pixel_loss[msk[sl]].sum()
                loss += float(chunk_loss)
                chunk_loss.backward()  # accumulates into the leaf .grad

            # Pair each original tensor with the gradient gathered on its
            # detached copy (presumably so the caller can finish the
            # back-propagation -- confirm against the training loop).
            loss = (loss, [(scores_, scores.grad)])
            if qconf_.requires_grad:
                loss[1].append((qconf_, qconf.grad))
        else:
            ap = self.aploss(scores, gt).view(msk.shape)
            pixel_loss = self.loss_from_ap(ap, qconf)
            loss = pixel_loss[msk].mean()

        return dict(ap_loss=loss)
def make_grid(B, H, W, device):
    """Enumerate every pixel of a batch of H x W maps.

    Returns a pair ``(batch_idx, xy)``:
        batch_idx: (B, H*W) integer tensor, the batch index of each pixel.
        xy: (B, H*W, 2) integer tensor, the (x, y) position of each pixel,
            listed row by row.
    """
    n_pix = H * W
    batch_idx = torch.arange(B, device=device)[:, None].expand(B, n_pix)

    rows = torch.arange(H, device=device)[None, :, None].expand(B, H, W)
    cols = torch.arange(W, device=device)[None, None, :].expand(B, H, W)
    xy = torch.stack([cols, rows], dim=-1).reshape(B, n_pix, 2)

    return batch_idx, xy
def aflow_from_H(H_1to2, feat1):
    """Build an absolute flow map from a homography.

    The pixel grid of ``feat1`` (taken from its 4-D shape and device) is
    passed through ``applyh`` with ``H_1to2`` -- presumably the homography
    from image 1 to image 2 -- and the result is laid out as a
    (B, 2, H, W) tensor.
    """
    batch, _, height, width = feat1.shape
    _, grid_xy = make_grid(batch, height, width, feat1.device)
    warped_xy = applyh(H_1to2, grid_xy.float())
    flow = warped_xy.view(batch, height, width, 2)
    # Move the coordinate axis next to the batch axis: (B, H, W, 2) -> (B, 2, H, W).
    return flow.permute(0, 3, 1, 2)