image-matching-webui / third_party /DKM /dkm /losses /depth_match_regression_loss.py
Vincentqyw
add: roma
c608946
raw
history blame
No virus
4.47 kB
from einops.einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from dkm.utils.utils import warp_kpts
class DepthRegressionLoss(nn.Module):
    """Multi-scale loss for dense match regression supervised by depth and pose.

    For every scale of the prediction, ground-truth matches are obtained by
    warping a regular grid of query coordinates into the support image with
    ``warp_kpts`` (using both depth maps, the relative pose ``T_1to2`` and the
    intrinsics ``K1``/``K2``). The per-scale loss is::

        ce_weight * BCE(certainty logits, valid mask)
        + mean Euclidean distance between predicted and warped coords (valid only)

    and the per-scale losses are summed.
    """

    def __init__(
        self,
        robust=True,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=4.0,
        local_largest_scale=8,
    ):
        """
        Args:
            robust: Stored as ``self.robust``; not read anywhere in this class.
            center_coords: Stored as ``self.center_coords``; not read anywhere
                in this class.
            scale_normalize: If True, each scale's loss is divided by ``scale``
                before being added to the total.
            ce_weight: Weight of the certainty (BCE) term relative to the
                regression term.
            local_loss: If True, scales ``<= local_largest_scale`` are only
                supervised where the previously processed scale was already
                close to the ground truth.
            local_dist: Tolerance of that check; multiplied by ``scale``
                (see ``_apply_local_mask``).
            local_largest_scale: Largest scale at which the local check applies.
        """
        super().__init__()
        # NOTE(review): ``robust`` and ``center_coords`` are kept for config
        # compatibility only; nothing in this class reads them.
        self.robust = robust
        self.center_coords = center_coords
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
        self.local_dist = local_dist
        self.local_largest_scale = local_largest_scale

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale):
        """Distance between predicted matches and depth/pose-warped query coords.

        Args:
            depth1, depth2: Query / support depth maps, forwarded to ``warp_kpts``.
            T_1to2: Relative pose query -> support, forwarded to ``warp_kpts``.
            K1, K2: Query / support intrinsics, forwarded to ``warp_kpts``.
            dense_matches: Predicted support coordinates of shape
                ``(b, h, w, 2)`` (last dim must broadcast against the warped
                ``(..., 2)`` coordinates), in the same normalized frame as the
                query grid built below.
            scale: Unused here; kept for interface compatibility.

        Returns:
            ``(gd, prob)``: ``gd`` is the ``(b, h, w)`` Euclidean distance
            between ``dense_matches`` and the warped grid (differentiable with
            respect to ``dense_matches``); ``prob`` is the ``(b, h, w)`` float
            version of the validity mask returned by ``warp_kpts``.
        """
        b, h1, w1, d = dense_matches.shape
        with torch.no_grad():
            # Pixel-centre grid in [-1, 1]; the batch axis is only there so the
            # meshgrid already has the batch dimension.
            x1_n = torch.meshgrid(
                *[
                    torch.linspace(
                        -1 + 1 / n, 1 - 1 / n, n, device=dense_matches.device
                    )
                    for n in (b, h1, w1)
                ]
            )
            # Stack as (x, y) = (width index, height index).
            x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(b, h1 * w1, 2)
            # Warp in double precision to limit numerical error in projection.
            mask, x2 = warp_kpts(
                x1_n.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            prob = mask.float().reshape(b, h1, w1)
        # Outside no_grad so the regression term back-propagates into the matches.
        gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
        return gd, prob

    def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8):
        """Certainty BCE and regression loss for one scale.

        Args:
            dense_certainty: Certainty logits; channel 0 is used
                (``dense_certainty[:, 0]`` must match ``prob``'s shape).
            prob: Target validity mask (values in [0, 1]); also selects which
                pixels contribute to the regression term (``prob > 0``).
            gd: Per-pixel distance from ``geometric_dist``.
            scale: Only used to build the returned dictionary keys.
            eps: Unused; kept for interface compatibility.

        Returns:
            ``{"ce_loss_<scale>": scalar, "depth_loss_<scale>": scalar}``.
        """
        ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], prob)
        valid = prob > 0
        if valid.any():
            depth_loss = gd[valid].mean()
        else:
            # No valid pixel: return a zero that is still attached to the graph
            # (a mean over an empty selection would be NaN).
            depth_loss = (gd * 0.0).mean()
        return {
            f"ce_loss_{scale}": ce_loss,
            f"depth_loss_{scale}": depth_loss,
        }

    def _apply_local_mask(self, prob, prev_gd, scale, h, w):
        """Drop supervision where the previously processed scale was far off.

        Fine scales should not be penalised for coarse mistakes, but their
        certainty should still learn to flag those pixels as unreliable.

        Args:
            prob: ``(b, h, w)`` supervision mask for the current scale.
            prev_gd: ``(b, h', w')`` distances of the previously processed
                scale, or ``None`` if there is none (then ``prob`` is returned
                unchanged).
            scale: Current scale.
            h, w: Spatial size of the current scale.

        Returns:
            ``prob`` multiplied by the boolean "previous scale was close" mask.
        """
        if prev_gd is None:
            return prob
        prev_gd_up = F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")[
            :, 0
        ]
        # ``local_dist * scale`` pixels converted to the normalized [-1, 1]
        # frame. NOTE(review): the 2 / 512 factor presumably assumes a
        # 512-pixel image side — confirm for other training resolutions.
        return prob * (prev_gd_up < (2 / 512) * (self.local_dist * scale))

    def forward(self, dense_corresps, batch):
        """Sum the per-scale losses over all predicted scales.

        Args:
            dense_corresps: Mapping ``scale -> {"dense_certainty": ...,
                "dense_flow": ...}`` where ``dense_flow`` is laid out as
                ``(b, d, h, w)``. Scales are processed in iteration order and
                the local check compares each scale with the one before it, so
                presumably the mapping is ordered coarse -> fine.
            batch: Dict providing ``"query_depth"``, ``"support_depth"``,
                ``"T_1to2"``, ``"K1"`` and ``"K2"``.

        Returns:
            Total loss (scalar tensor; the float ``0.0`` if ``dense_corresps``
            is empty).
        """
        tot_loss = 0.0
        # Distances of the previously processed scale; None until one exists,
        # so a first scale that already qualifies for the local check no
        # longer crashes.
        prev_gd = None
        for scale, dense_scale_corresps in dense_corresps.items():
            dense_scale_certainty = dense_scale_corresps["dense_certainty"]
            dense_scale_coords = rearrange(
                dense_scale_corresps["dense_flow"], "b d h w -> b h w d"
            )
            _, h, w, _ = dense_scale_coords.shape
            gd, prob = self.geometric_dist(
                batch["query_depth"],
                batch["support_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                dense_scale_coords,
                scale,
            )
            if scale <= self.local_largest_scale and self.local_loss:
                prob = self._apply_local_mask(prob, prev_gd, scale, h, w)
            depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale)
            scale_loss = (
                self.ce_weight * depth_losses[f"ce_loss_{scale}"]
                + depth_losses[f"depth_loss_{scale}"]
            )
            if self.scale_normalize:
                scale_loss = scale_loss / scale
            tot_loss = tot_loss + scale_loss
            prev_gd = gd.detach()
        return tot_loss