# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tools.trainer import backward


class MultiLoss (nn.Module):
    """ This functions handles both supervised and unsupervised samples.
    """
    def __init__(self, loss_sup, loss_unsup, alpha=0.3, inner_bw=True):
        super().__init__()
        assert 0 <= alpha
        self.alpha_sup = 1 # coef of self-supervised loss
        self.loss_sup = loss_sup

        self.alpha_unsup = alpha # coef of unsupervised loss
        self.loss_unsup = loss_unsup

        self.inner_bw = inner_bw

    def forward(self, desc1, desc2, homography, **kw):
        # the batch is sorted: supervised samples (finite homography) come first
        sl_sup, sl_unsup = split_batch_sup_unsup(homography, 512 if self.inner_bw else 8)

        # inner backward: back-propagate each sub-loss immediately down to detached
        # copies of the descriptors, then hand their gradients back to the trainer
        inner_bw = self.inner_bw and self.training and torch.is_grad_enabled()
        if inner_bw:
            (desc1, desc1_), (desc2, desc2_) = pause_gradient((desc1, desc2))
        kw['desc1'], kw['desc2'], kw['homography'] = desc1, desc2, homography

        kw_sup = {k: v[sl_sup] for k, v in kw.items()}
        kw_unsup = {k: v[sl_unsup] for k, v in kw.items()}

        (sup_name, sup_loss), = self.loss_sup(backward_loss=inner_bw * self.alpha_sup, **kw_sup).items()
        if inner_bw and sup_loss:
            sup_loss = backward(sup_loss)  # backward down to desc1 and desc2

        (uns_name, uns_loss), = self.loss_unsup(**kw_unsup).items()
        uns_loss = self.alpha_unsup * uns_loss
        if inner_bw and uns_loss:
            uns_loss = backward(uns_loss)  # backward down to desc1 and desc2

        loss = sup_loss + uns_loss
        if inner_bw:
            # hand back the accumulated descriptor gradients so that the trainer
            # can resume back-propagation from the original (attached) tensors
            return {'loss': (loss, [(desc1_, desc1.grad), (desc2_, desc2.grad)]),
                    sup_name: float(sup_loss), uns_name: float(uns_loss)}
        return {'loss': loss, sup_name: float(sup_loss), uns_name: float(uns_loss)}


def pause_gradient(objs):
    """ Detach each tensor from the current graph while keeping a handle on the
        original, so back-propagation can later be resumed from the originals. """
    return [(obj.detach().requires_grad_(True), obj) for obj in objs]


def split_batch_sup_unsup(homography, max_sup=512):
    # split the batch into supervised / unsupervised samples: supervised samples
    # have a finite homography, unsupervised ones have a NaN homography
    i = int(torch.isfinite(homography[:,0,0]).sum())  # index of the first unsupervised sample
    sl_sup, sl_unsup = slice(0, min(i, max_sup)), slice(i, None)

    assert torch.isfinite(homography[sl_sup]).all(), 'batch is not properly sorted!'
    assert torch.isnan(homography[sl_unsup]).all(), 'batch is not properly sorted!'
    return sl_sup, sl_unsup
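

# ---------------------------------------------------------------------------
# Usage sketch (not part of the training pipeline): the toy loss, names and
# tensor shapes below are hypothetical stand-ins, only meant to illustrate the
# calling convention MultiLoss assumes, i.e. each sub-loss returns a
# single-entry dict {name: scalar loss}, and the batch is sorted with
# supervised samples (finite homography) first and unsupervised ones last.
if __name__ == '__main__':
    class _ToyLoss (nn.Module):
        # hypothetical stand-in for a real matching loss
        def __init__(self, name):
            super().__init__()
            self.name = name
        def forward(self, desc1, desc2, homography=None, backward_loss=0, **kw):
            return {self.name: F.mse_loss(desc1, desc2)}

    B, D = 6, 128
    desc1 = torch.randn(B, D, requires_grad=True)
    desc2 = torch.randn(B, D, requires_grad=True)
    homography = torch.cat([torch.eye(3).expand(4, 3, 3),           # 4 supervised samples
                            torch.full((2, 3, 3), float('nan'))])   # 2 unsupervised samples

    multi_loss = MultiLoss(_ToyLoss('sup_loss'), _ToyLoss('unsup_loss'), alpha=0.3, inner_bw=False)
    print(multi_loss(desc1, desc2, homography))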