# NOTE: the page-scrape header ("Spaces: Running Running") that preceded this module was not part of the source.
import math
from collections.abc import Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from einops.einops import rearrange

import roma
from roma.utils.utils import get_gt_warp
class RobustLosses(nn.Module):
    """Multi-scale training loss for certainty and correspondence predictions.

    ``forward`` walks over the scales of a correspondence dict, builds a
    ground-truth warp per scale with ``get_gt_warp`` and accumulates:

    * an optional global-matcher term (classification if ``gm_cls`` is
      present, otherwise regression if ``gm_flow`` is present), and
    * a refinement term (classification if ``delta_cls`` is present,
      otherwise regression on ``flow``).

    Every term is ``ce_weight * certainty_BCE + correspondence_loss`` and each
    individual loss is also logged to ``wandb`` at ``roma.GLOBAL_STEP``.

    Args:
        ce_weight: Weight of the certainty (binary cross-entropy) terms.
        local_dist: Locality radius used to restrict supervision at fine
            scales. Either a single number or a mapping ``{scale: radius}``.
        local_largest_scale: Largest scale at which the locality mask is
            applied.
        alpha: Exponent of the robust regression loss. Either a single number
            or a mapping ``{scale: alpha}``.
        c: Base scale of the robust regression loss (multiplied by the scale).

    The remaining constructor arguments (``robust``, ``center_coords``,
    ``scale_normalize``, ``local_loss``, ``smooth_mask``,
    ``depth_interpolation_mode``, ``mask_depth_loss``,
    ``relative_depth_error_threshold``) are stored on the instance but are not
    read by any method of this class.
    """

    def __init__(
        self,
        robust=False,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=4.0,
        local_largest_scale=8,
        smooth_mask=False,
        depth_interpolation_mode="bilinear",
        mask_depth_loss=False,
        relative_depth_error_threshold=0.05,
        alpha=1.0,
        c=1e-3,
    ):
        super().__init__()
        self.robust = robust  # measured in pixels
        self.center_coords = center_coords
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
        self.local_dist = local_dist
        self.local_largest_scale = local_largest_scale
        self.smooth_mask = smooth_mask
        self.depth_interpolation_mode = depth_interpolation_mode
        self.mask_depth_loss = mask_depth_loss
        self.relative_depth_error_threshold = relative_depth_error_threshold
        self.avg_overlap = dict()
        self.alpha = alpha
        self.c = c

    @staticmethod
    def _per_scale(value, scale):
        """Return ``value[scale]`` for mappings and ``value`` itself otherwise."""
        if isinstance(value, Mapping):
            return value[scale]
        return value

    @staticmethod
    def _class_centers(num_classes, device):
        """Return the ``(num_classes, 2)`` bin centres of a square grid on [-1, 1].

        ``num_classes`` is expected to be a perfect square; the reshape fails
        otherwise. Centres are ordered row-major with the last dimension
        holding ``(column coordinate, row coordinate)``.
        """
        cls_res = round(math.sqrt(num_classes))
        axis = torch.linspace(
            -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
        )
        # Explicit "ij" indexing matches the legacy default of torch.meshgrid.
        rows, cols = torch.meshgrid(axis, axis, indexing="ij")
        return torch.stack((cols, rows), dim=-1).reshape(num_classes, 2)

    @staticmethod
    def _or_zero(per_pixel_loss, fallback):
        """Replace an empty / all-zero loss by a zero that keeps the graph alive.

        When no pixel passes the ``prob > 0.99`` mask the selected loss is
        empty and its mean would be NaN; ``fallback * 0.0`` yields a zero that
        is still connected to the autograd graph of ``fallback``.
        """
        if not torch.any(per_pixel_loss):
            return fallback * 0.0  # Prevent issues where prob is 0 everywhere
        return per_pixel_loss

    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
        """Classification loss for the global matcher at one scale.

        Args:
            x2: Ground-truth target coordinates; assumed ``(B, H, W, 2)`` from
                the broadcasting used below.
            prob: Ground-truth match probability, assumed ``(B, H, W)``.
            scale_gm_cls: Class logits of shape ``(B, C, H, W)``.
            gm_certainty: Certainty logits; channel 0 is compared to ``prob``.
            scale: Scale identifier, used only to name the logged losses.

        Returns:
            Dict with ``gm_certainty_loss_{scale}`` and ``gm_cls_loss_{scale}``.
        """
        with torch.no_grad():
            _, num_classes, _, _ = scale_gm_cls.shape
            centers = self._class_centers(num_classes, x2.device)
            # Target class = index of the bin centre closest to the GT point.
            GT = (
                (centers[None, :, None, None, :] - x2[:, None])
                .norm(dim=-1)
                .min(dim=1)
                .indices
            )
        # Computed before the fallback below, which needs it.
        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:, 0], prob)
        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction="none")[prob > 0.99]
        cls_loss = self._or_zero(cls_loss, certainty_loss)
        losses = {
            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
            f"gm_cls_loss_{scale}": cls_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def delta_cls_loss(
        self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale
    ):
        """Classification loss for the refinement offset at one scale.

        The candidate positions are ``flow_pre_delta`` plus a grid of offsets
        scaled by ``offset_scale``; the target class is the candidate closest
        to ``x2``.

        Args:
            x2: Ground-truth target coordinates; assumed ``(B, H, W, 2)``.
            prob: Ground-truth match probability, assumed ``(B, H, W)``.
            flow_pre_delta: Flow before refinement, same layout as ``x2``.
            delta_cls: Offset class logits of shape ``(B, C, H, W)``.
            certainty: Certainty logits; channel 0 is compared to ``prob``.
            scale: Scale identifier, used only to name the logged losses.
            offset_scale: Multiplier applied to the offset grid.

        Returns:
            Dict with ``delta_certainty_loss_{scale}`` and
            ``delta_cls_loss_{scale}``.
        """
        with torch.no_grad():
            _, num_classes, _, _ = delta_cls.shape
            offsets = self._class_centers(num_classes, x2.device) * offset_scale
            GT = (
                (
                    offsets[None, :, None, None, :]
                    + flow_pre_delta[:, None]
                    - x2[:, None]
                )
                .norm(dim=-1)
                .min(dim=1)
                .indices
            )
        # Computed before the fallback below, which needs it.
        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
        cls_loss = F.cross_entropy(delta_cls, GT, reduction="none")[prob > 0.99]
        cls_loss = self._or_zero(cls_loss, certainty_loss)
        losses = {
            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
            f"delta_cls_loss_{scale}": cls_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
        """Robust regression loss on the end-point error at one scale.

        Args:
            x2: Ground-truth target coordinates; assumed ``(B, H, W, 2)``.
            prob: Ground-truth match probability, assumed ``(B, H, W)``.
            flow: Predicted flow, channel-first (it is permuted to
                channel-last before the subtraction).
            certainty: Certainty logits; channel 0 is compared to ``prob``.
            scale: Scale identifier; also multiplies ``self.c``.
            eps: Unused; kept for interface compatibility.
            mode: Prefix of the returned / logged loss names.

        Returns:
            Dict with ``{mode}_certainty_loss_{scale}`` and
            ``{mode}_regression_loss_{scale}``.
        """
        epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1)
        valid = prob > 0.99
        if scale == 1:
            # Fraction of valid pixels with error below 0.5 * (2 / 512);
            # presumably half a pixel at 512 px in [-1, 1] coords - TODO confirm.
            pck_05 = (epe[valid] < 0.5 * (2 / 512)).float().mean()
            wandb.log({"train_pck_05": pck_05}, step=roma.GLOBAL_STEP)
        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
        a = self._per_scale(self.alpha, scale)
        cs = self.c * scale
        x = epe[valid]
        reg_loss = cs**a * ((x / cs) ** 2 + 1**2) ** (a / 2)
        reg_loss = self._or_zero(reg_loss, ce_loss)
        losses = {
            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def forward(self, corresps, batch):
        """Sum the per-scale losses over all scales in ``corresps``.

        Args:
            corresps: Mapping ``scale -> dict`` with required keys
                ``"certainty"``, ``"flow_pre_delta"`` and ``"flow"`` and
                optional keys ``"delta_cls"``, ``"offset_scale"``,
                ``"gm_cls"``, ``"gm_certainty"`` and ``"gm_flow"``. Scales are
                processed in the dict's iteration order.
            batch: Dict providing ``"im_A_depth"``, ``"im_B_depth"``,
                ``"T_1to2"``, ``"K1"`` and ``"K2"`` for ``get_gt_warp``.

        Returns:
            The accumulated total loss.
        """
        tot_loss = 0.0
        # scale_weights due to differences in scale for regression gradients and classification gradients
        scale_weights = {1: 1, 2: 1, 4: 1, 8: 1, 16: 1}
        # End-point error of the previously processed scale; None until one exists.
        prev_epe = None
        for scale, scale_corresps in corresps.items():
            scale_certainty = scale_corresps["certainty"]
            flow_pre_delta = scale_corresps["flow_pre_delta"]
            delta_cls = scale_corresps.get("delta_cls")
            offset_scale = scale_corresps.get("offset_scale")
            scale_gm_cls = scale_corresps.get("gm_cls")
            scale_gm_certainty = scale_corresps.get("gm_certainty")
            flow = scale_corresps["flow"]
            scale_gm_flow = scale_corresps.get("gm_flow")
            # Every listed weight is 1, so unlisted scales default to 1 as well.
            weight = scale_weights.get(scale, 1)

            flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
            _, h, w, _ = flow_pre_delta.shape
            gt_warp, gt_prob = get_gt_warp(
                batch["im_A_depth"],
                batch["im_B_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                H=h,
                W=w,
            )
            x2 = gt_warp.float()
            prob = gt_prob

            # At fine scales only supervise pixels whose previous estimate was
            # already close to the ground truth. Skipped while no previous
            # estimate is available.
            if prev_epe is not None and self.local_largest_scale >= scale:
                local_dist = self._per_scale(self.local_dist, scale)
                prev_epe_resized = F.interpolate(
                    prev_epe[:, None], size=(h, w), mode="nearest-exact"
                )[:, 0]
                prob = prob * (prev_epe_resized < (2 / 512) * (local_dist * scale))

            if scale_gm_cls is not None:
                gm_cls_losses = self.gm_cls_loss(
                    x2, prob, scale_gm_cls, scale_gm_certainty, scale
                )
                gm_loss = (
                    self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"]
                    + gm_cls_losses[f"gm_cls_loss_{scale}"]
                )
                tot_loss = tot_loss + weight * gm_loss
            elif scale_gm_flow is not None:
                gm_flow_losses = self.regression_loss(
                    x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm"
                )
                gm_loss = (
                    self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"]
                    + gm_flow_losses[f"gm_regression_loss_{scale}"]
                )
                tot_loss = tot_loss + weight * gm_loss

            if delta_cls is not None:
                delta_cls_losses = self.delta_cls_loss(
                    x2,
                    prob,
                    flow_pre_delta,
                    delta_cls,
                    scale_certainty,
                    scale,
                    offset_scale,
                )
                delta_cls_loss = (
                    self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"]
                    + delta_cls_losses[f"delta_cls_loss_{scale}"]
                )
                tot_loss = tot_loss + weight * delta_cls_loss
            else:
                delta_regression_losses = self.regression_loss(
                    x2, prob, flow, scale_certainty, scale
                )
                reg_loss = (
                    self.ce_weight
                    * delta_regression_losses[f"delta_certainty_loss_{scale}"]
                    + delta_regression_losses[f"delta_regression_loss_{scale}"]
                )
                tot_loss = tot_loss + weight * reg_loss

            # Detached so the locality mask of the next scale carries no gradient.
            prev_epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1).detach()
        return tot_loss