File size: 1,118 Bytes
5e88f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from .reconstruction_loss import ReconstructionLoss
import torch


class CriterionDict:
    """Evaluate a set of named, weighted loss criteria as one combined loss.

    ``criterions`` maps a name to a 3-tuple
    ``(criterion, loss_multiplier, anneal_fn)``:

    * ``criterion`` is called as
      ``criterion(sample, flow, masks_softmaxed, iteration, train=train)``
      and its result must support ``.item()``.
    * ``loss_multiplier`` is a constant factor applied to that result.
    * ``anneal_fn`` is called with the iteration and yields a second factor.

    The flow helper methods delegate to the first element of the entry stored
    under the ``'reconstruction'`` key, so they raise ``KeyError`` when no such
    entry exists.
    """

    def __init__(self, dict):
        # The parameter name shadows the builtin; it is kept as-is so callers
        # passing it by keyword keep working.
        self.criterions = dict

    def __call__(self, sample, flow, masks_softmaxed, iteration, train=True, prefix=''):
        """Compute every criterion and return ``(total_loss, log_dict)``.

        ``total_loss`` is a zero-dimensional tensor created on the device and
        with the dtype of ``masks_softmaxed``; each weighted term is added to
        it in place. ``log_dict`` holds one ``'loss_<name>'`` float per
        criterion plus ``'loss_total'``.

        ``prefix`` is accepted but not used by this method.
        """
        total = torch.zeros((), device=masks_softmaxed.device, dtype=masks_softmaxed.dtype)
        log_dict = {}

        for name, entry in self.criterions.items():
            criterion, multiplier, anneal_fn = entry
            # Scalar weight first, then the criterion call -- same evaluation
            # order as a single left-to-right product.
            weight = multiplier * anneal_fn(iteration)
            term = weight * criterion(sample, flow, masks_softmaxed, iteration, train=train)
            # In-place accumulation keeps the accumulator's dtype and device.
            total += term
            log_dict[f'loss_{name}'] = term.item()

        log_dict['loss_total'] = total.item()
        return total, log_dict

    def _reconstruction(self):
        """Return the criterion object registered under ``'reconstruction'``."""
        return self.criterions['reconstruction'][0]

    def flow_reconstruction(self, sample, flow, masks_softmaxed):
        """Delegate to ``rec_flow`` of the reconstruction criterion."""
        reconstruction = self._reconstruction()
        return reconstruction.rec_flow(sample, flow, masks_softmaxed)

    def process_flow(self, sample, flow):
        """Delegate to ``process_flow`` of the reconstruction criterion."""
        reconstruction = self._reconstruction()
        return reconstruction.process_flow(sample, flow)

    def viz_flow(self, flow):
        """Delegate to ``viz_flow`` of the reconstruction criterion."""
        reconstruction = self._reconstruction()
        return reconstruction.viz_flow(flow)