import kornia
import torch

from .utils import get_image_coords
from .wrappers import Camera


def sample_fmap(pts, fmap):
    """Sample a feature map at continuous pixel locations.

    Args:
        pts: (B, N, 2) pixel coordinates (x, y). They are normalized as
            ``pts / (w, h) * 2 - 1`` and sampled with ``align_corners=False``,
            so 0 and w (resp. h) map to the outer edges of the image.
        fmap: (B, C, H, W) feature map.

    Returns:
        (B, N, C) sampled values. Bilinear interpolation is used; wherever it
        yields NaN (e.g. a NaN pixel in the bilinear footprint), the
        nearest-neighbour sample is used instead.
    """
    h, w = fmap.shape[-2:]
    grid_sample = torch.nn.functional.grid_sample
    # Map pixel coordinates to grid_sample's [-1, 1] range and add a unit
    # "output height" dimension so the grid is (B, 1, N, 2).
    pts = (pts / pts.new_tensor([[w, h]]) * 2 - 1)[:, None]
    # TODO: this might still be a source of noise -- bilinear interpolation
    # mixes values across discontinuities.
    interp_lin = grid_sample(fmap, pts, align_corners=False, mode="bilinear")
    interp_nn = grid_sample(fmap, pts, align_corners=False, mode="nearest")
    interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)
    # (B, C, 1, N) -> (B, C, N) -> (B, N, C)
    return interp[:, :, 0].permute(0, 2, 1)


def sample_depth(pts, depth_):
    """Sample a depth map at pixel locations, treating depth <= 0 as missing.

    Args:
        pts: (B, N, 2) pixel coordinates, same convention as ``sample_fmap``.
        depth_: (B, H, W) depth map; non-positive values are invalid.

    Returns:
        interp: (B, N) sampled depth, NaN where no valid depth was found.
        valid: (B, N) boolean mask of non-NaN, strictly positive samples.
    """
    # Invalid depth becomes NaN so bilinear sampling falls back to nearest.
    depth = torch.where(depth_ > 0, depth_, depth_.new_tensor(float("nan")))
    depth = depth[:, None]
    interp = sample_fmap(pts, depth).squeeze(-1)
    valid = (~torch.isnan(interp)) & (interp > 0)
    return interp, valid


def sample_normals_from_depth(pts, depth, K):
    """Compute surface normals from a depth map and sample them at ``pts``.

    Args:
        pts: (B, N, 2) pixel coordinates, same convention as ``sample_fmap``.
        depth: (B, H, W) depth map; normals at pixels with depth <= 0 are
            zeroed.
        K: camera intrinsics passed to
            ``kornia.geometry.depth.depth_to_normals`` (kornia expects a
            (B, 3, 3) camera matrix).

    Returns:
        interp: (B, N, 3) sampled normals.
        valid: (B, N, 3) element-wise mask.
    """
    depth = depth[:, None]
    normals = kornia.geometry.depth.depth_to_normals(depth, K)
    normals = torch.where(depth > 0, normals, 0.0)
    interp = sample_fmap(pts, normals)
    # NOTE(review): ``interp > 0`` is applied per normal component, so any
    # normal with a negative component is marked invalid. This looks copied
    # from ``sample_depth``; a non-zero / norm-based test is probably intended.
    # Left unchanged because callers may rely on the (B, N, 3) mask.
    valid = (~torch.isnan(interp)) & (interp > 0)
    return interp, valid


def project(
    kpi,
    di,
    depthj,
    camera_i,
    camera_j,
    T_itoj,
    validi,
    ccth=None,
    sample_depth_fun=sample_depth,
    sample_depth_kwargs=None,
):
    """Project keypoints from view i into view j, optionally with a cycle check.

    ``kpi`` is lifted to 3D with depths ``di`` via ``camera_i.image2cam``,
    transformed with ``T_itoj.transform`` and projected with
    ``camera_j.cam2image``.

    If both ``depthj`` and ``ccth`` are given, a cycle-consistency check is
    added. View j's depth is sampled at the projected points, and those points
    are lifted again. They are then mapped back with ``T_itoj.inv()`` and
    re-projected into view i. Points whose squared reprojection distance to
    ``kpi`` is not below ``ccth`` are marked invalid.

    Args:
        kpi: keypoints in view i, coordinates along the last dimension.
        di: depths of ``kpi``; shape ``kpi.shape[:-1]``.
        depthj: depth map of view j, or None to skip the cycle check.
        camera_i, camera_j: camera models providing ``image2cam`` and
            ``cam2image`` (the latter returns ``(points, valid)``).
        T_itoj: transform from view i to view j, providing ``transform`` and
            ``inv``.
        validi: boolean mask of valid input keypoints, shape ``di.shape``.
        ccth: threshold on the *squared* pixel distance for the cycle check,
            or None to skip it.
        sample_depth_fun: callable ``(pts, depthj, **kw) -> (depth, valid)``
            used to read ``depthj``.
        sample_depth_kwargs: extra keyword arguments for ``sample_depth_fun``.

    Returns:
        kpi_j: keypoints projected into view j.
        valid: mask of points that are valid in view i, project validly into
            view j and, if enabled, pass the cycle check.
    """
    if sample_depth_kwargs is None:
        sample_depth_kwargs = {}

    kpi_3d_i = camera_i.image2cam(kpi) * di[..., None]
    kpi_3d_j = T_itoj.transform(kpi_3d_i)
    kpi_j, validj = camera_j.cam2image(kpi_3d_j)
    validi = validi & validj

    if depthj is None or ccth is None:
        return kpi_j, validi

    # Cycle consistency: lift the projections with view j's own depth and map
    # them back into view i.
    dj, validj = sample_depth_fun(kpi_j, depthj, **sample_depth_kwargs)
    kpi_j_3d_j = camera_j.image2cam(kpi_j) * dj[..., None]
    kpi_j_i, validj_i = camera_i.cam2image(T_itoj.inv().transform(kpi_j_3d_j))
    consistent = ((kpi - kpi_j_i) ** 2).sum(-1) < ccth
    visible = validi & consistent & validj_i & validj
    return kpi_j, visible


def dense_warp_consistency(
    depthi: torch.Tensor,
    depthj: torch.Tensor,
    T_itoj: torch.Tensor,
    camerai: Camera,
    cameraj: Camera,
    **kwargs,
):
    """Warp every pixel of view i into view j.

    Args:
        depthi: (..., H_i, W_i) depth map of view i; pixels with depth <= 0
            are marked invalid.
        depthj: depth map of view j, used by ``project`` for the optional
            cycle-consistency check (may be None to skip it).
        T_itoj: transform from view i to view j. NOTE(review): annotated as
            a tensor, but ``project`` calls ``.transform`` / ``.inv()`` on
            it, so it must be a pose object -- confirm the intended type.
        camerai, cameraj: camera models of views i and j.
        **kwargs: forwarded to ``project`` (``ccth``, ``sample_depth_fun``,
            ``sample_depth_kwargs``).

    Returns:
        kpir: (..., H_i, W_i, D) positions of view-i pixels in view j
            (D as returned by ``cam2image``, presumably 2).
        validir: (..., H_i, W_i) boolean validity mask.
    """
    # Assumes get_image_coords returns (..., H, W, 2) pixel coordinates.
    kpi = get_image_coords(depthi).flatten(-3, -2)
    di = depthi.flatten(-2)
    validi = di > 0
    kpir, validir = project(kpi, di, depthj, camerai, cameraj, T_itoj, validi, **kwargs)
    # Both outputs are indexed by the pixels of view i, so restore depthi's
    # (H, W) grid -- depthj may have a different size or be None.
    hw = depthi.shape[-2:]
    return kpir.unflatten(-2, hw), validir.unflatten(-1, hw)