PUMP / core /losses /multiloss.py
Philippe Weinzaepfel
huggingface demo
3ef85e9
raw
history blame contribute delete
No virus
2.27 kB
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tools.trainer import backward
class MultiLoss (nn.Module):
""" This functions handles both supervised and unsupervised samples.
"""
def __init__(self, loss_sup, loss_unsup, alpha=0.3, inner_bw=True):
super().__init__()
assert 0 <= alpha
self.alpha_sup = 1 # coef of self-supervised loss
self.loss_sup = loss_sup
self.alpha_unsup = alpha # coef of unsupervised loss
self.loss_unsup = loss_unsup
self.inner_bw = inner_bw
def forward(self, desc1, desc2, homography, **kw):
sl_sup, sl_unsup = split_batch_sup_unsup(homography, 512 if self.inner_bw else 8)
inner_bw = self.inner_bw and self.training and torch.is_grad_enabled()
if inner_bw: (desc1, desc1_), (desc2, desc2_) = pause_gradient((desc1,desc2))
kw['desc1'], kw['desc2'], kw['homography'] = desc1, desc2, homography
(sup_name, sup_loss) ,= self.loss_sup(backward_loss=inner_bw*self.alpha_sup, **{k:v[sl_sup] for k,v in kw.items()}).items()
if inner_bw and sup_loss: sup_loss = backward(sup_loss) # backward to desc1 and desc2
(uns_name, uns_loss) ,= self.loss_unsup(**{k:v[sl_unsup] for k,v in kw.items()}).items()
uns_loss = self.alpha_unsup * uns_loss
if inner_bw and uns_loss: uns_loss = backward(uns_loss) # backward to desc1 and desc2
loss = sup_loss + uns_loss
return {'loss':(loss, [(desc1_,desc1.grad),(desc2_,desc2.grad)]), sup_name:float(sup_loss), uns_name:float(uns_loss)}
def pause_gradient( objs ):
return [(obj.detach().requires_grad_(True), obj) for obj in objs]
def split_batch_sup_unsup(homography, max_sup=512):
# split batch in supervised / unsupervised
i = int(torch.isfinite(homography[:,0,0]).sum()) # first ocurence
sl_sup, sl_unsup = slice(0, min(i,max_sup)), slice(i, None)
assert torch.isfinite(homography[sl_sup]).all(), 'batch is not properly sorted!'
assert torch.isnan(homography[sl_unsup]).all(), 'batch is not properly sorted!'
return sl_sup, sl_unsup