# Guess-What-Moves / losses / reconstruction_loss.py
# Last commit: "Code Commit" by subhc (5e88f62)
import torch
import functools
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import flow_reconstruction
import utils
from utils.visualisation import flow2rgb_torch
# Module-level logger obtained through the project's `utils.log` helper.
# Code in this module calls `logger.debug_once(...)` in addition to
# `logger.info(...)`, so the returned object must provide `debug_once`.
logger = utils.log.getLogger(__name__)
class ReconstructionLoss:
    """Loss between an input flow and a flow reconstructed from soft masks.

    The reconstruction is produced by ``self._recon_fn`` (set to
    :meth:`flow_quad`, which delegates to
    ``flow_reconstruction.get_quad_flow``), its last two channels are clipped
    to configured bounds, and the result is compared against the input flow
    with an L1 or L2 criterion.

    The instance is callable: ``loss_obj(sample, flow, masks, it)`` is
    equivalent to ``loss_obj.loss(sample, flow, masks, it)``.
    """

    def __init__(self, cfg, model):
        """Read the loss configuration from ``cfg.GWM`` and build the grid.

        Args:
            cfg: Configuration object exposing a ``GWM`` namespace with the
                attributes read below (``CRITERION``, ``L1_OPTIMIZE``,
                ``HOMOGRAPHY``, ``RESOLUTION``, ``FLOW_COLORSPACE_REC``,
                ``HOMOGRAPHY_SUBSAMPLE``, ``HOMOGRAPHY_SKIP``,
                ``FLOW_CLIP_{U,V}_{LOW,HIGH}`` and, later, ``FLOW_RES``).
            model: Object exposing a ``device`` attribute; the mesh grid is
                created on that device.
        """
        # 'L2' selects mean-squared error; any other value falls back to L1.
        self.criterion = nn.MSELoss() if cfg.GWM.CRITERION == 'L2' else nn.L1Loss()

        # Stored for reference; not read by any method of this class.
        self.l1_optimize = cfg.GWM.L1_OPTIMIZE
        self.homography = cfg.GWM.HOMOGRAPHY
        self.flow_colorspace_rec = cfg.GWM.FLOW_COLORSPACE_REC

        self.device = model.device
        self.cfg = cfg
        self.grid_x, self.grid_y = utils.grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)

        # NOTE: this configures the `flow_reconstruction` module itself, so it
        # affects every user of that module, not only this instance.
        flow_reconstruction.set_subsample_skip(cfg.GWM.HOMOGRAPHY_SUBSAMPLE, cfg.GWM.HOMOGRAPHY_SKIP)

        # Clipping bounds applied to the last two channels of a reconstruction.
        self.flow_u_low = cfg.GWM.FLOW_CLIP_U_LOW
        self.flow_u_high = cfg.GWM.FLOW_CLIP_U_HIGH
        self.flow_v_low = cfg.GWM.FLOW_CLIP_V_LOW
        self.flow_v_high = cfg.GWM.FLOW_CLIP_V_HIGH

        self._recon_fn = self.flow_quad
        logger.info(f'Using reconstruction method {self._recon_fn.__name__}')

        self.it = 0
        # Defined up front so the attribute exists before the first `loss`
        # call; `loss` overwrites it with its `train` argument every time.
        self.training = True
        self._extra_losses = []

    def __call__(self, sample, flow, masks_softmaxed, it, train=True):
        """Forward to :meth:`loss` with the same arguments."""
        return self.loss(sample, flow, masks_softmaxed, it, train=train)

    def loss(self, sample, flow, mask_softmaxed, it, train=True):
        """Compute the reconstruction loss for one batch.

        Args:
            sample: Passed through to :meth:`process_flow` and the
                reconstruction function; not inspected by this class.
            flow: Target flow tensor; its last two dimensions are treated as
                the spatial size.
            mask_softmaxed: Mask tensor. When ``cfg.GWM.FLOW_RES`` is set and
                its spatial size differs from ``flow``'s, it is bilinearly
                resized to match (which requires a 4-D tensor).
            it: Iteration counter, stored on ``self.it`` and forwarded to the
                reconstruction function.
            train: Stored on ``self.training``.

        Returns:
            The criterion averaged over all reconstructions, plus the mean of
            ``self._extra_losses`` if anything was appended to that list while
            reconstructing.
        """
        self.training = train
        flow = self.process_flow(sample, flow)
        self.it = it
        self._extra_losses = []

        flow_hw = flow.shape[-2:]
        if self.cfg.GWM.FLOW_RES is not None and flow_hw != mask_softmaxed.shape[-2:]:
            logger.debug_once(f'Resizing predicted masks to {self.cfg.GWM.FLOW_RES}')
            mask_softmaxed = F.interpolate(mask_softmaxed, flow_hw, mode='bilinear', align_corners=False)

        rec_flows = self.rec_flow(sample, flow, mask_softmaxed)
        if not isinstance(rec_flows, (list, tuple)):
            rec_flows = (rec_flows,)

        # Each term is divided by k so the sum is the mean over reconstructions.
        k = len(rec_flows)
        loss = sum(self.criterion(flow, rec_flow) / k for rec_flow in rec_flows)

        if self._extra_losses:
            loss = loss + sum(self._extra_losses, 0.) / len(self._extra_losses)
            self._extra_losses = []

        return loss

    def flow_quad(self, sample, flow, masks_softmaxed, it, **_):
        """Reconstruct flow via ``flow_reconstruction.get_quad_flow``.

        ``sample``, ``it`` and any extra keyword arguments are accepted for
        signature compatibility and ignored.
        """
        logger.debug_once(f'Reconstruction using quadratic. Masks shape {masks_softmaxed.shape} | '
                          f'Flow shape {flow.shape} | '
                          f'Grid shape {self.grid_x.shape, self.grid_y.shape}')
        return flow_reconstruction.get_quad_flow(masks_softmaxed, flow, self.grid_x, self.grid_y)

    def _clipped_recon_fn(self, *args, **kwargs):
        """Run ``self._recon_fn`` and clip the last two channels of its output.

        Along dim 1, the second-to-last channel is clipped to
        ``[flow_u_low, flow_u_high]`` and the last channel to
        ``[flow_v_low, flow_v_high]``; any preceding channels pass through
        unchanged.
        """
        flow = self._recon_fn(*args, **kwargs)
        flow_o = flow[:, :-2]
        flow_u = flow[:, -2:-1].clip(self.flow_u_low, self.flow_u_high)
        flow_v = flow[:, -1:].clip(self.flow_v_low, self.flow_v_high)
        return torch.cat([flow_o, flow_u, flow_v], dim=1)

    def rec_flow(self, sample, flow, masks_softmaxed):
        """Return a one-element list holding the clipped reconstruction.

        When ``cfg.GWM.FLOW_RES`` is set and the cached grid's spatial size
        differs from ``flow``'s, the grid is regenerated at the flow's size
        and cached on the instance first.
        """
        it = self.it
        if self.cfg.GWM.FLOW_RES is not None and flow.shape[-2:] != self.grid_x.shape[-2:]:
            logger.debug_once(f'Generating new grid predicted masks of {flow.shape[-2:]}')
            self.grid_x, self.grid_y = utils.grid.get_meshgrid(flow.shape[-2:], self.device)
        return [self._clipped_recon_fn(sample, flow, masks_softmaxed, it)]

    def process_flow(self, sample, flow_cuda):
        """Return ``flow_cuda`` unchanged (pre-processing hook; ``sample`` is unused)."""
        return flow_cuda

    def viz_flow(self, flow):
        """Apply ``flow2rgb_torch`` to each item along dim 0 and stack the results."""
        return torch.stack([flow2rgb_torch(x) for x in flow])