from einops.einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from romatch.utils.utils import get_gt_warp
import wandb
import romatch
import math

# This loss is slightly different from regular romatch due to significantly worse corresps.
# The confidence loss is quite tricky here //Johan
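#
# Overview: a scale whose predictions include a correlation volume is supervised
# with a dual-softmax loss over that volume (corr_volume_loss); every scale's
# refined flow is supervised with a robust regression loss plus a binary
# cross-entropy loss on the predicted certainty (regression_loss).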
class RobustLosses(nn.Module):
    def __init__(
        self,
        robust=False,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=None,
        smooth_mask=False,
        depth_interpolation_mode="bilinear",
        mask_depth_loss=False,
        relative_depth_error_threshold=0.05,
        alpha=1.0,
        c=1e-3,
        epe_mask_prob_th=None,
        cert_only_on_consistent_depth=False,
    ):
        super().__init__()
        if local_dist is None:
            local_dist = {}
        self.robust = robust  # measured in pixels
        self.center_coords = center_coords
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
        self.local_dist = local_dist
        self.smooth_mask = smooth_mask
        self.depth_interpolation_mode = depth_interpolation_mode
        self.mask_depth_loss = mask_depth_loss
        self.relative_depth_error_threshold = relative_depth_error_threshold
        self.avg_overlap = dict()
        self.alpha = alpha
        self.c = c
        self.epe_mask_prob_th = epe_mask_prob_th
        self.cert_only_on_consistent_depth = cert_only_on_consistent_depth
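
    # corr_volume_loss: mnn holds (batch, i, j) triplets of grid positions that
    # are mutually nearest under the forward and backward ground-truth warps;
    # the loss is the dual-softmax negative log-likelihood of the correlation
    # volume evaluated at exactly those entries.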
    def corr_volume_loss(self, mnn: torch.Tensor, corr_volume: torch.Tensor, scale):
        b, h, w, h, w = corr_volume.shape  # both grids share the same (h, w)
        inv_temp = 10
        corr_volume = corr_volume.reshape(-1, h * w, h * w)
        # Softmax over each matching direction; supervise only at the mutual
        # nearest neighbours.
        nll = -(inv_temp * corr_volume).log_softmax(dim=1) - (inv_temp * corr_volume).log_softmax(dim=2)
        corr_volume_loss = nll[mnn[:, 0], mnn[:, 1], mnn[:, 2]].mean()

        losses = {
            f"gm_corr_volume_loss_{scale}": corr_volume_loss.mean(),
        }
        wandb.log(losses, step=romatch.GLOBAL_STEP)
        return losses
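
    # regression_loss: epe is the end-point error between the predicted flow and
    # the ground-truth warp x2 (in [-1, 1] normalized coordinates). The
    # certainty head is trained with BCE against the ground-truth validity mask,
    # and the flow is penalized with a generalized Charbonnier:
    #     reg_loss = cs^a * ((epe / cs)^2 + 1)^(a / 2),  with cs = c * scale,
    # computed only where the ground-truth warp is confident (prob > 0.99).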
    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
        epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1)
        if scale in self.local_dist:
            prob = prob * (epe < (2 / 512) * (self.local_dist[scale] * scale)).float()
        if scale == 1:
            pck_05 = (epe[prob > 0.99] < 0.5 * (2 / 512)).float().mean()
            wandb.log({"train_pck_05": pck_05}, step=romatch.GLOBAL_STEP)

        if self.epe_mask_prob_th is not None:
            # If the prediction is too far from the ground truth, certainty should be 0.
            gt_cert = prob * (epe < scale * self.epe_mask_prob_th)
        else:
            gt_cert = prob
        if self.cert_only_on_consistent_depth:
            ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0][prob > 0], gt_cert[prob > 0])
        else:
            ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], gt_cert)

        a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
        cs = self.c * scale
        x = epe[prob > 0.99]
        reg_loss = cs**a * ((x / cs) ** 2 + 1) ** (a / 2)
        if not torch.any(reg_loss):
            reg_loss = ce_loss * 0.0  # Prevent issues when prob is 0 everywhere
        losses = {
            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
        }
        wandb.log(losses, step=romatch.GLOBAL_STEP)
        return losses
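
    # forward: corresps maps scale -> predictions for that scale. Every scale
    # contributes the delta regression loss; a scale that also provides a
    # corr_volume (or, failing that, a gm_flow) additionally contributes the
    # corresponding global-matcher loss.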
    def forward(self, corresps, batch):
        scales = list(corresps.keys())
        tot_loss = 0.0
        # scale_weights due to differences in scale for regression gradients and classification gradients
        for scale in scales:
            scale_corresps = corresps[scale]
            (
                scale_certainty,
                flow_pre_delta,
                delta_cls,
                offset_scale,
                scale_gm_corr_volume,
                scale_gm_certainty,
                flow,
                scale_gm_flow,
            ) = (
                scale_corresps["certainty"],
                scale_corresps.get("flow_pre_delta"),
                scale_corresps.get("delta_cls"),
                scale_corresps.get("offset_scale"),
                scale_corresps.get("corr_volume"),
                scale_corresps.get("gm_certainty"),
                scale_corresps["flow"],
                scale_corresps.get("gm_flow"),
            )
            if flow_pre_delta is not None:
                flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
                b, h, w, d = flow_pre_delta.shape
            else:
                b, _, h, w = scale_certainty.shape
            gt_warp, gt_prob = get_gt_warp(
                batch["im_A_depth"],
                batch["im_B_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                H=h,
                W=w,
            )
            x2 = gt_warp.float()
            prob = gt_prob

            if scale_gm_corr_volume is not None:
                # Warp in the reverse direction as well, so that mutual nearest
                # neighbours can be extracted from both ground-truth warps.
                gt_warp_back, _ = get_gt_warp(
                    batch["im_B_depth"],
                    batch["im_A_depth"],
                    batch["T_1to2"].inverse(),
                    batch["K2"],
                    batch["K1"],
                    H=h,
                    W=w,
                )
                grid = torch.stack(
                    torch.meshgrid(
                        torch.linspace(-1 + 1 / w, 1 - 1 / w, w),
                        torch.linspace(-1 + 1 / h, 1 - 1 / h, h),
                        indexing="xy",
                    ),
                    dim=-1,
                ).to(gt_warp.device)
                # fwd_bck = F.grid_sample(gt_warp_back.permute(0,3,1,2), gt_warp, align_corners=False, mode='bilinear').permute(0,2,3,1)
                # diff = (fwd_bck - grid).norm(dim=-1)
                with torch.no_grad():
                    # Pairwise distances between warped positions and the grid;
                    # a pair is kept if it is nearest in both directions and
                    # closer than 0.01 in normalized coordinates.
                    D_B = torch.cdist(gt_warp.float().reshape(-1, h * w, 2), grid.reshape(-1, h * w, 2))
                    D_A = torch.cdist(grid.reshape(-1, h * w, 2), gt_warp_back.float().reshape(-1, h * w, 2))
                    inds = torch.nonzero(
                        (D_B == D_B.min(dim=-1, keepdim=True).values)
                        * (D_A == D_A.min(dim=-2, keepdim=True).values)
                        * (D_B < 0.01)
                        * (D_A < 0.01)
                    )
                gm_cls_losses = self.corr_volume_loss(inds, scale_gm_corr_volume, scale)
                gm_loss = gm_cls_losses[f"gm_corr_volume_loss_{scale}"]
                tot_loss = tot_loss + gm_loss
            elif scale_gm_flow is not None:
                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm")
                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
                tot_loss = tot_loss + gm_loss

            delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
            reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
            tot_loss = tot_loss + reg_loss
        return tot_loss
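

# Minimal standalone sketch (not part of the training code): illustrates how the
# generalized Charbonnier penalty from regression_loss behaves for a few
# end-point errors. All values below are illustrative assumptions, not the
# defaults used in training (alpha defaults to 1.0 above).
if __name__ == "__main__":
    epe = torch.linspace(0, 4 / 512, steps=5)  # hypothetical end-point errors
    c, alpha, scale = 1e-3, 0.5, 1  # assumed hyperparameters
    cs = c * scale
    penalty = cs**alpha * ((epe / cs) ** 2 + 1) ** (alpha / 2)
    print(penalty)  # grows sublinearly in epe, so outliers are down-weighted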