import torch
from torch_scatter import scatter_sum

from . import fastba
from . import lietorch
from .lietorch import SE3

from .utils import Timer
from . import projective_ops as pops


class CholeskySolver(torch.autograd.Function):
    @staticmethod
    def forward(ctx, H, b):
        # don't crash training if cholesky decomp fails
        # (cholesky_ex returns the lower-triangular factor by default,
        # which matches cholesky_solve's default upper=False)
        U, info = torch.linalg.cholesky_ex(H)

        if torch.any(info):
            ctx.failed = True
            return torch.zeros_like(b)

        xs = torch.cholesky_solve(b, U)
        ctx.save_for_backward(U, xs)
        ctx.failed = False

        return xs

    @staticmethod
    def backward(ctx, grad_x):
        if ctx.failed:
            return None, None

        U, xs = ctx.saved_tensors
        dz = torch.cholesky_solve(grad_x, U)
        dH = -torch.matmul(xs, dz.transpose(-1,-2))

        return dH, dz


# utility functions for scattering ops
def safe_scatter_add_mat(A, ii, jj, n, m):
    # drop edges whose indices fall outside the [0,n) x [0,m) grid
    v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m)
    return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m)

def safe_scatter_add_vec(b, ii, n):
    v = (ii >= 0) & (ii < n)
    return scatter_sum(b[:,v], ii[v], dim=1, dim_size=n)

# apply retraction operator to inv-depth maps
def disp_retr(disps, dz, ii):
    ii = ii.to(device=dz.device)
    return disps + scatter_sum(dz, ii, dim=1, dim_size=disps.shape[1])

# apply retraction operator to poses
def pose_retr(poses, dx, ii):
    ii = ii.to(device=dx.device)
    return poses.retr(scatter_sum(dx, ii, dim=1, dim_size=poses.shape[1]))


def block_matmul(A, B):
    """ block matrix multiply """
    b, n1, m1, p1, q1 = A.shape
    b, n2, m2, p2, q2 = B.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    B = B.permute(0, 1, 3, 2, 4).reshape(b, n2*p2, m2*q2)
    return torch.matmul(A, B).reshape(b, n1, p1, m2, q2).permute(0, 1, 3, 2, 4)


def block_solve(A, B, ep=1.0, lm=1e-4):
    """ block matrix solve """
    b, n1, m1, p1, q1 = A.shape
    b, n2, m2, p2, q2 = B.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    B = B.permute(0, 1, 3, 2, 4).reshape(b, n2*p2, m2*q2)

    # Levenberg-Marquardt style damping: add ep + lm * diag(A) to the diagonal
    A = A + (ep + lm * A) * torch.eye(n1*p1, device=A.device)
    X = CholeskySolver.apply(A, B)

    return X.reshape(b, n1, p1, m2, q2).permute(0, 1, 3, 2, 4)


def block_show(A):
    import matplotlib.pyplot as plt
    b, n1, m1, p1, q1 = A.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    plt.imshow(A[0].detach().cpu().numpy())
    plt.show()


def BA(poses, patches, intrinsics, targets, weights, lmbda, ii, jj, kk,
       bounds, ep=100.0, PRINT=False, fixedp=1, structure_only=False):
    """ bundle adjustment """

    b = 1
    n = max(ii.max().item(), jj.max().item()) + 1

    coords, v, (Ji, Jj, Jz) = \
        pops.transform(poses, patches, intrinsics, ii, jj, kk, jacobian=True)

    p = coords.shape[3]
    r = targets - coords[...,p//2,p//2,:]

    # reject edges with residuals larger than 250 pixels
    v *= (r.norm(dim=-1) < 250).float()

    # ... and edges whose reprojection falls outside the image bounds
    in_bounds = \
        (coords[...,p//2,p//2,0] > bounds[0]) & \
        (coords[...,p//2,p//2,1] > bounds[1]) & \
        (coords[...,p//2,p//2,0] < bounds[2]) & \
        (coords[...,p//2,p//2,1] < bounds[3])

    v *= in_bounds.float()

    if PRINT:
        print((r * v[...,None]).norm(dim=-1).mean().item())

    r = (v[...,None] * r).unsqueeze(dim=-1)
    weights = (v[...,None] * weights).unsqueeze(dim=-1)

    wJiT = (weights * Ji).transpose(2,3)
    wJjT = (weights * Jj).transpose(2,3)
    wJzT = (weights * Jz).transpose(2,3)

    Bii = torch.matmul(wJiT, Ji)
    Bij = torch.matmul(wJiT, Jj)
    Bji = torch.matmul(wJjT, Ji)
    Bjj = torch.matmul(wJjT, Jj)

    Eik = torch.matmul(wJiT, Jz)
    Ejk = torch.matmul(wJjT, Jz)

    vi = torch.matmul(wJiT, r)
    vj = torch.matmul(wJjT, r)

    # fix first pose
    ii = ii.clone()
    jj = jj.clone()

    n = n - fixedp
    ii = ii - fixedp
    jj = jj - fixedp

    # re-index the patches that actually participate in this problem
    kx, kk = torch.unique(kk, return_inverse=True, sorted=True)
    m = len(kx)
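    # Assemble the blocks of the (damped) Gauss-Newton normal equations
    #
    #   [ B  E ] [ dX ]   [ v ]
    #   [ E' C ] [ dZ ] = [ w ]
    #
    # B: n x n grid of 6x6 pose-pose blocks, E: n x m grid of 6x1
    # pose-depth coupling blocks, C: one scalar per patch depth, so it is
    # diagonal and trivially invertible -- which is what makes the Schur
    # complement step below cheap.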
    B = safe_scatter_add_mat(Bii, ii, ii, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bij, ii, jj, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bji, jj, ii, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bjj, jj, jj, n, n).view(b, n, n, 6, 6)

    E = safe_scatter_add_mat(Eik, ii, kk, n, m).view(b, n, m, 6, 1) + \
        safe_scatter_add_mat(Ejk, jj, kk, n, m).view(b, n, m, 6, 1)

    C = safe_scatter_add_vec(torch.matmul(wJzT, Jz), kk, m)

    v = safe_scatter_add_vec(vi, ii, n).view(b, n, 1, 6, 1) + \
        safe_scatter_add_vec(vj, jj, n).view(b, n, 1, 6, 1)

    w = safe_scatter_add_vec(torch.matmul(wJzT, r), kk, m)

    if isinstance(lmbda, torch.Tensor):
        lmbda = lmbda.reshape(*C.shape)

    # C is diagonal, so its damped inverse is elementwise
    Q = 1.0 / (C + lmbda)

    ### solve w/ schur complement ###
    EQ = E * Q[:,None]

    if structure_only or n == 0:
        # no free poses: depth update comes directly from the diagonal system
        dZ = (Q * w).view(b, -1, 1, 1)

    else:
        # reduced (pose-only) system S dX = y, then back-substitute for dZ
        S = B - block_matmul(EQ, E.permute(0,2,1,4,3))
        y = v - block_matmul(EQ, w.unsqueeze(dim=2))
        dX = block_solve(S, y, ep=ep, lm=1e-4)

        dZ = Q * (w - block_matmul(E.permute(0,2,1,4,3), dX).squeeze(dim=-1))
        dX = dX.view(b, -1, 6)
        dZ = dZ.view(b, -1, 1, 1)

    # retract the inverse-depth component of each patch, clamped to a
    # plausible range
    x, y, disps = patches.unbind(dim=2)
    disps = disp_retr(disps, dZ, kx).clamp(min=1e-3, max=10.0)
    patches = torch.stack([x, y, disps], dim=2)

    # retract poses, skipping the first `fixedp` poses which are held fixed
    if not structure_only and n > 0:
        poses = pose_retr(poses, dX, fixedp + torch.arange(n))

    return poses, patches
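
if __name__ == "__main__":
    # Minimal smoke test for the block solver above -- an illustrative
    # sketch with made-up shapes, not part of the BA pipeline. Because this
    # module uses relative imports, run it as a module, e.g.
    # `python -m <package>.ba` (package name depends on your checkout).
    torch.manual_seed(0)

    # a 3x3 grid of 6x6 blocks, made symmetric positive definite via G @ G^T
    G = torch.randn(1, 3, 3, 6, 6, dtype=torch.float64)
    A = block_matmul(G, G.permute(0, 2, 1, 4, 3))

    # block right-hand side: three 6x1 blocks
    B = torch.randn(1, 3, 1, 6, 1, dtype=torch.float64)

    # disable damping so A @ X should reproduce B up to roundoff
    X = block_solve(A, B, ep=0.0, lm=0.0)
    R = block_matmul(A, X) - B
    print("solve residual:", R.abs().max().item())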