Guess-What-Moves / losses /__init__.py
subhc's picture
Code Commit
5e88f62
from .reconstruction_loss import ReconstructionLoss
import torch
class CriterionDict:
    """Bundle of named, weighted loss terms evaluated as one callable.

    ``criterions`` maps a name to a ``(criterion, multiplier, anneal_fn)``
    triple.  Calling the instance evaluates every criterion, scales it by
    ``multiplier * anneal_fn(iteration)`` and sums the results.  The flow
    helper methods delegate to the criterion stored under the
    ``'reconstruction'`` key.
    """

    def __init__(self, dict):
        # NOTE: the parameter name shadows the builtin ``dict``; it is kept
        # unchanged because callers may pass it by keyword.
        self.criterions = dict

    def __call__(self, sample, flow, masks_softmaxed, iteration, train=True, prefix=''):
        """Evaluate all criteria and return ``(total_loss, log_dict)``.

        Args:
            sample: Forwarded untouched to every criterion.
            flow: Forwarded untouched to every criterion.
            masks_softmaxed: Forwarded to every criterion; its ``device``
                and ``dtype`` are also used to create the accumulator.
            iteration: Passed to each anneal function and each criterion.
            train: Forwarded to every criterion as the ``train`` keyword.
            prefix: Accepted but not used by this method.

        Returns:
            A tuple of the summed loss tensor and a dict holding
            ``loss_<name>`` for every criterion plus ``loss_total``, all
            as Python numbers obtained via ``.item()``.
        """
        total = torch.tensor(
            0.,
            device=masks_softmaxed.device,
            dtype=masks_softmaxed.dtype,
        )
        logs = {}
        for name, entry in self.criterions.items():
            criterion, multiplier, anneal_fn = entry
            # The anneal function is evaluated before the criterion runs.
            weight = multiplier * anneal_fn(iteration)
            term = weight * criterion(sample, flow, masks_softmaxed, iteration, train=train)
            # In-place accumulation keeps the accumulator's dtype/device.
            total += term
            logs[f'loss_{name}'] = term.item()
        logs['loss_total'] = total.item()
        return total, logs

    def _reconstruction_criterion(self):
        """Return the criterion object registered under 'reconstruction'."""
        return self.criterions['reconstruction'][0]

    def flow_reconstruction(self, sample, flow, masks_softmaxed):
        """Delegate to ``rec_flow`` of the reconstruction criterion."""
        reconstruction = self._reconstruction_criterion()
        return reconstruction.rec_flow(sample, flow, masks_softmaxed)

    def process_flow(self, sample, flow):
        """Delegate to ``process_flow`` of the reconstruction criterion."""
        reconstruction = self._reconstruction_criterion()
        return reconstruction.process_flow(sample, flow)

    def viz_flow(self, flow):
        """Delegate to ``viz_flow`` of the reconstruction criterion."""
        reconstruction = self._reconstruction_criterion()
        return reconstruction.viz_flow(flow)