import torch
import functools
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

import flow_reconstruction
import utils
from utils.visualisation import flow2rgb_torch

logger = utils.log.getLogger(__name__)


class ReconstructionLoss:
    """Loss between an input flow and flow(s) reconstructed from soft masks.

    Reconstruction is done by ``flow_quad``, which delegates to
    ``flow_reconstruction.get_quad_flow``. The reconstruction's last two
    channels are then clipped to the configured ``FLOW_CLIP_*`` bounds.
    Finally, the configured criterion (MSE or L1) compares the input flow
    with each reconstruction.
    """

    def __init__(self, cfg, model):
        """Read loss settings from ``cfg.GWM`` and build the initial meshgrid.

        Args:
            cfg: config object; every ``cfg.GWM.*`` field read below must exist.
            model: object exposing ``.device``, used to place the meshgrid.
        """
        # 'L2' selects mean-squared error; any other value falls back to L1.
        self.criterion = nn.MSELoss() if cfg.GWM.CRITERION == 'L2' else nn.L1Loss()
        self.l1_optimize = cfg.GWM.L1_OPTIMIZE
        self.homography = cfg.GWM.HOMOGRAPHY
        self.device = model.device
        self.cfg = cfg
        self.grid_x, self.grid_y = utils.grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)
        # self.mult_flow = cfg.GWM.USE_MULT_FLOW
        self.flow_colorspace_rec = cfg.GWM.FLOW_COLORSPACE_REC
        # NOTE(review): presumably configures module-level state inside
        # flow_reconstruction (nothing is returned/stored here) — confirm.
        flow_reconstruction.set_subsample_skip(cfg.GWM.HOMOGRAPHY_SUBSAMPLE, cfg.GWM.HOMOGRAPHY_SKIP)
        # Clip bounds applied to the last two reconstructed channels; None
        # means "unbounded on that side".
        self.flow_u_low = cfg.GWM.FLOW_CLIP_U_LOW
        self.flow_u_high = cfg.GWM.FLOW_CLIP_U_HIGH
        self.flow_v_low = cfg.GWM.FLOW_CLIP_V_LOW
        self.flow_v_high = cfg.GWM.FLOW_CLIP_V_HIGH
        self._recon_fn = self.flow_quad
        logger.info(f'Using reconstruction method {self._recon_fn.__name__}')
        self.it = 0
        # Additional loss terms; averaged into the total, then cleared, by loss().
        self._extra_losses = []

    def __call__(self, sample, flow, masks_softmaxed, it, train=True):
        """Alias for :meth:`loss`."""
        return self.loss(sample, flow, masks_softmaxed, it, train=train)

    def loss(self, sample, flow, mask_softmaxed, it, train=True):
        """Compute the reconstruction loss for one batch.

        Args:
            sample: batch object passed through to ``process_flow`` and the
                reconstruction function (unused by the code in this class).
            flow: input flow tensor; its last two dims are treated as spatial.
            mask_softmaxed: soft segmentation masks (assumed N, K, H, W since
                they are bilinearly interpolated — TODO confirm).
            it: iteration counter, stored on ``self.it``.
            train: stored on ``self.training``.

        Returns:
            The criterion value averaged over all reconstructions, plus the
            mean of any accumulated ``_extra_losses``.
        """
        self.training = train
        flow = self.process_flow(sample, flow)
        self.it = it
        self._extra_losses = []

        if self.cfg.GWM.FLOW_RES is not None:
            if flow.shape[-2:] != mask_softmaxed.shape[-2:]:
                logger.debug_once(f'Resizing predicted masks to {self.cfg.GWM.FLOW_RES}')
                mask_softmaxed = F.interpolate(mask_softmaxed, flow.shape[-2:], mode='bilinear', align_corners=False)

        rec_flows = self.rec_flow(sample, flow, mask_softmaxed)
        if not isinstance(rec_flows, (list, tuple)):
            rec_flows = (rec_flows,)
        k = len(rec_flows)
        # Average the criterion over every reconstruction.
        loss = sum(self.criterion(flow, rec_flow) / k for rec_flow in rec_flows)
        if len(self._extra_losses):
            loss = loss + sum(self._extra_losses, 0.) / len(self._extra_losses)
        self._extra_losses = []
        return loss

    def flow_quad(self, sample, flow, masks_softmaxed, it, **_):
        """Reconstruct flow via ``flow_reconstruction.get_quad_flow``.

        ``sample`` and ``it`` are accepted for interface compatibility but are
        not used here.
        """
        logger.debug_once(f'Reconstruction using quadratic. Masks shape {masks_softmaxed.shape} | '
                          f'Flow shape {flow.shape} | '
                          f'Grid shape {self.grid_x.shape, self.grid_y.shape}')
        return flow_reconstruction.get_quad_flow(masks_softmaxed, flow, self.grid_x, self.grid_y)

    @staticmethod
    def _clip_channel(channel, low, high):
        """Clip ``channel`` to ``[low, high]``; either bound may be None.

        ``Tensor.clip`` raises when *both* bounds are None, so that case is
        treated as "no clipping" and the tensor is returned unchanged.
        """
        if low is None and high is None:
            return channel
        return channel.clip(low, high)

    def _clipped_recon_fn(self, *args, **kwargs):
        """Run the reconstruction function and clip its last two channels.

        Channels before the last two pass through unchanged. The
        second-to-last channel is clipped to the ``flow_u_*`` bounds and the
        last to the ``flow_v_*`` bounds (presumably the u/v flow components
        along dim 1 — TODO confirm layout).
        """
        flow = self._recon_fn(*args, **kwargs)
        flow_o = flow[:, :-2]
        flow_u = self._clip_channel(flow[:, -2:-1], self.flow_u_low, self.flow_u_high)
        flow_v = self._clip_channel(flow[:, -1:], self.flow_v_low, self.flow_v_high)
        return torch.cat([flow_o, flow_u, flow_v], dim=1)

    def rec_flow(self, sample, flow, masks_softmaxed):
        """Return a one-element list holding the clipped flow reconstruction.

        When ``cfg.GWM.FLOW_RES`` is set and the flow's spatial size differs
        from the cached meshgrid, the meshgrid is rebuilt at the flow's size.
        """
        it = self.it
        if self.cfg.GWM.FLOW_RES is not None and flow.shape[-2:] != self.grid_x.shape[-2:]:
            logger.debug_once(f'Generating new grid predicted masks of {flow.shape[-2:]}')
            self.grid_x, self.grid_y = utils.grid.get_meshgrid(flow.shape[-2:], self.device)
        return [self._clipped_recon_fn(sample, flow, masks_softmaxed, it)]

    def process_flow(self, sample, flow_cuda):
        """Hook for pre-processing the input flow; identity here."""
        return flow_cuda

    def viz_flow(self, flow):
        """Convert each element along dim 0 to RGB with ``flow2rgb_torch`` and stack."""
        return torch.stack([flow2rgb_torch(x) for x in flow])