Spaces:
Sleeping
Sleeping
import torch | |
from torch_scatter import scatter_sum | |
from . import fastba | |
from . import lietorch | |
from .lietorch import SE3 | |
from .utils import Timer | |
from . import projective_ops as pops | |
class CholeskySolver(torch.autograd.Function):
    """Differentiable solve of ``H x = b`` through a Cholesky factorization.

    When the factorization fails for any matrix in the batch (``cholesky_ex``
    reports a non-zero ``info``), the forward pass returns zeros shaped like
    ``b`` and the backward pass returns no gradients, so an ill-conditioned
    system does not abort training.

    Use as ``CholeskySolver.apply(H, b)``.
    """

    @staticmethod
    def forward(ctx, H, b):
        """Return ``x`` with ``H x = b``, or zeros like ``b`` on failure."""
        # don't crash training if cholesky decomp fails
        # cholesky_ex returns the lower-triangular factor by default, which is
        # what cholesky_solve expects with its default upper=False.
        chol, info = torch.linalg.cholesky_ex(H)
        if torch.any(info):
            ctx.failed = True
            return torch.zeros_like(b)

        xs = torch.cholesky_solve(b, chol)
        ctx.save_for_backward(chol, xs)
        ctx.failed = False
        return xs

    @staticmethod
    def backward(ctx, grad_x):
        """Return gradients ``(dH, db)``; ``(None, None)`` if forward failed."""
        if ctx.failed:
            return None, None

        chol, xs = ctx.saved_tensors
        # Gradient w.r.t. b: solve the same system against the incoming grad.
        dz = torch.cholesky_solve(grad_x, chol)
        # Gradient w.r.t. H, formed as -xs @ dz^T.
        dH = -torch.matmul(xs, dz.transpose(-1, -2))
        return dH, dz
# utility functions for scattering ops
def safe_scatter_add_mat(A, ii, jj, n, m):
    """Scatter-add entries of ``A`` (along dim 1) into a flattened n-by-m grid.

    Entry ``k`` is accumulated into slot ``ii[k] * m + jj[k]``. Entries whose
    ``(ii, jj)`` pair lies outside ``[0, n) x [0, m)`` are dropped, so
    out-of-range (e.g. negative) indices are silently ignored.

    Returns the result of ``scatter_sum`` with ``n * m`` slots along dim 1.
    """
    row_ok = (ii >= 0) & (ii < n)
    col_ok = (jj >= 0) & (jj < m)
    keep = row_ok & col_ok

    flat_index = ii[keep] * m + jj[keep]
    return scatter_sum(A[:, keep], flat_index, dim=1, dim_size=n * m)
def safe_scatter_add_vec(b, ii, n):
    """Scatter-add entries of ``b`` (along dim 1) into ``n`` slots.

    Entry ``k`` is accumulated into slot ``ii[k]``; entries whose index lies
    outside ``[0, n)`` are dropped.
    """
    keep = (ii >= 0) & (ii < n)
    kept_values = b[:, keep]
    kept_index = ii[keep]
    return scatter_sum(kept_values, kept_index, dim=1, dim_size=n)
# apply retraction operator to inv-depth maps
def disp_retr(disps, dz, ii):
    """Add the updates ``dz`` to ``disps`` at the slots named by ``ii``.

    ``ii`` is moved to ``dz``'s device, ``dz`` is scatter-summed along dim 1
    into a tensor with ``disps.shape[1]`` slots, and that sum is added to
    ``disps``. Slots not referenced by ``ii`` are left unchanged.
    """
    index = ii.to(device=dz.device)
    num_slots = disps.shape[1]
    update = scatter_sum(dz, index, dim=1, dim_size=num_slots)
    return disps + update
# apply retraction operator to poses
def pose_retr(poses, dx, ii):
    """Retract ``poses`` by the updates ``dx`` at the slots named by ``ii``.

    ``ii`` is moved to ``dx``'s device, ``dx`` is scatter-summed along dim 1
    into ``poses.shape[1]`` slots, and the result is passed to ``poses.retr``.
    ``poses`` must expose ``.shape`` and ``.retr`` (presumably a lietorch SE3
    group object -- confirm against the caller).
    """
    index = ii.to(device=dx.device)
    num_poses = poses.shape[1]
    update = scatter_sum(dx, index, dim=1, dim_size=num_poses)
    return poses.retr(update)
def block_matmul(A, B):
    """Multiply two block matrices stored as 5-D tensors.

    Both operands use the layout ``(batch, block_rows, block_cols,
    block_height, block_width)``. Each is flattened into a dense
    ``(batch, block_rows * block_height, block_cols * block_width)`` matrix,
    the dense matrices are multiplied, and the product is folded back into
    block layout ``(batch, A_block_rows, B_block_cols, A_block_height,
    B_block_width)``.
    """
    _, a_rows, a_cols, a_h, a_w = A.shape
    # The batch size used for every reshape is taken from B.
    batch, b_rows, b_cols, b_h, b_w = B.shape

    # Swap the block-column and block-height axes so that reshape yields a
    # conventional dense matrix.
    dense_a = A.transpose(2, 3).reshape(batch, a_rows * a_h, a_cols * a_w)
    dense_b = B.transpose(2, 3).reshape(batch, b_rows * b_h, b_cols * b_w)

    product = dense_a @ dense_b
    return product.reshape(batch, a_rows, a_h, b_cols, b_w).transpose(2, 3)
def block_solve(A, B, ep=1.0, lm=1e-4):
    """Solve the block linear system ``A X = B`` with a damped Cholesky solve.

    ``A`` and ``B`` use the 5-D block layout ``(batch, block_rows,
    block_cols, block_height, block_width)``. Both are flattened to dense
    matrices, the diagonal of ``A`` is damped by ``ep + lm * diag(A)``, and the
    system is solved with ``CholeskySolver`` (which yields zeros if the
    factorization fails). The solution is returned in block layout.

    Args:
        A: block system matrix.
        B: block right-hand side.
        ep: constant added to every diagonal entry.
        lm: multiplier for the diagonal-proportional damping term.
    """
    _, a_rows, a_cols, a_h, a_w = A.shape
    # The batch size used for every reshape is taken from B.
    batch, b_rows, b_cols, b_h, b_w = B.shape

    dense_a = A.permute(0, 1, 3, 2, 4).reshape(batch, a_rows * a_h, a_cols * a_w)
    dense_b = B.permute(0, 1, 3, 2, 4).reshape(batch, b_rows * b_h, b_cols * b_w)

    # Multiplying by the identity keeps only the diagonal of (ep + lm * A).
    identity = torch.eye(a_rows * a_h, device=dense_a.device)
    damped = dense_a + (ep + lm * dense_a) * identity

    solution = CholeskySolver.apply(damped, dense_b)
    return solution.reshape(batch, a_rows, a_h, b_cols, b_w).permute(0, 1, 3, 2, 4)
def block_show(A):
    """Plot the first batch element of a 5-D block matrix as a dense image.

    Opens a matplotlib window via ``plt.show()``; intended for debugging.
    """
    import matplotlib.pyplot as plt

    batch, block_rows, block_cols, block_h, block_w = A.shape
    dense = A.permute(0, 1, 3, 2, 4).reshape(
        batch, block_rows * block_h, block_cols * block_w)

    image = dense[0].detach().cpu().numpy()
    plt.imshow(image)
    plt.show()
def BA(poses, patches, intrinsics, targets, weights, lmbda, ii, jj, kk, bounds, ep=100.0, PRINT=False, fixedp=1, structure_only=False):
    """Run one bundle-adjustment update of poses and patch inverse depths.

    Builds the weighted normal equations from the reprojection residuals and
    Jacobians returned by ``pops.transform``, eliminates the depth variables
    with a Schur complement, solves for the pose update, back-substitutes for
    the depth update, and applies both updates.

    Args:
        poses: pose object supporting ``.retr`` (presumably lietorch SE3 --
            confirm against the caller).
        patches: tensor whose dim 2 unbinds into ``(x, y, disps)``.
        intrinsics: forwarded unchanged to ``pops.transform``.
        targets: target 2-D locations compared with the reprojected patch
            centres.
        weights: per-residual weights; zeroed for invalid edges.
        lmbda: damping added to the depth block ``C``; a tensor is reshaped to
            ``C``'s shape, anything else is added as-is.
        ii, jj: pose index tensors for each edge (source / destination as
            consumed by ``pops.transform``).
        kk: patch index tensor for each edge.
        bounds: indexable of four values ``(xmin, ymin, xmax, ymax)``; edges
            whose reprojected patch centre falls outside are ignored.
        ep: constant diagonal damping forwarded to ``block_solve``.
        PRINT: when true, print the mean masked residual norm.
        fixedp: number of leading poses held fixed (their indices become
            negative after shifting and are dropped by the safe scatters).
        structure_only: when true, update only the depths.

    Returns:
        ``(poses, patches)`` after the update; depths are clamped to
        ``[1e-3, 10.0]``.
    """
    b = 1
    # Number of free (optimised) poses once the fixed ones are excluded.
    n = max(ii.max().item(), jj.max().item()) + 1 - fixedp

    coords, valid, (Ji, Jj, Jz) = pops.transform(
        poses, patches, intrinsics, ii, jj, kk, jacobian=True)

    # Reprojected location of each patch centre.
    half = coords.shape[3] // 2
    centre = coords[..., half, half, :]
    resid = targets - centre

    # Reject edges with a very large residual (threshold presumably in
    # pixels -- confirm) or a centre outside the given bounds.
    valid *= (resid.norm(dim=-1) < 250).float()

    cx = centre[..., 0]
    cy = centre[..., 1]
    in_bounds = (cx > bounds[0]) & (cy > bounds[1]) & (cx < bounds[2]) & (cy < bounds[3])
    valid *= in_bounds.float()

    mask = valid[..., None]
    if PRINT:
        print((resid * mask).norm(dim=-1).mean().item())

    resid = (mask * resid).unsqueeze(dim=-1)
    weights = (mask * weights).unsqueeze(dim=-1)

    # Weighted, transposed Jacobians.
    wJiT = (weights * Ji).transpose(2, 3)
    wJjT = (weights * Jj).transpose(2, 3)
    wJzT = (weights * Jz).transpose(2, 3)

    # Per-edge contributions to the normal equations.
    Bii = wJiT @ Ji
    Bij = wJiT @ Jj
    Bji = wJjT @ Ji
    Bjj = wJjT @ Jj

    Eik = wJiT @ Jz
    Ejk = wJjT @ Jz

    vi = wJiT @ resid
    vj = wJjT @ resid

    Ckk = wJzT @ Jz
    wk = wJzT @ resid

    # fix first pose: shifted indices of fixed poses are negative and are
    # therefore skipped by the safe scatter helpers.
    pi = ii - fixedp
    pj = jj - fixedp

    # Compact the patch indices to 0..m-1; kx maps back to the original ids.
    kx, kc = torch.unique(kk, return_inverse=True, sorted=True)
    m = len(kx)

    B = (
        safe_scatter_add_mat(Bii, pi, pi, n, n).view(b, n, n, 6, 6)
        + safe_scatter_add_mat(Bij, pi, pj, n, n).view(b, n, n, 6, 6)
        + safe_scatter_add_mat(Bji, pj, pi, n, n).view(b, n, n, 6, 6)
        + safe_scatter_add_mat(Bjj, pj, pj, n, n).view(b, n, n, 6, 6)
    )

    E = (
        safe_scatter_add_mat(Eik, pi, kc, n, m).view(b, n, m, 6, 1)
        + safe_scatter_add_mat(Ejk, pj, kc, n, m).view(b, n, m, 6, 1)
    )

    C = safe_scatter_add_vec(Ckk, kc, m)

    g_pose = (
        safe_scatter_add_vec(vi, pi, n).view(b, n, 1, 6, 1)
        + safe_scatter_add_vec(vj, pj, n).view(b, n, 1, 6, 1)
    )

    w = safe_scatter_add_vec(wk, kc, m)

    if isinstance(lmbda, torch.Tensor):
        lmbda = lmbda.reshape(*C.shape)

    # Inverse of the damped depth block.
    Q = 1.0 / (C + lmbda)

    ### solve w/ schur complement ###
    EQ = E * Q[:, None]

    if not structure_only and n != 0:
        Et = E.permute(0, 2, 1, 4, 3)
        S = B - block_matmul(EQ, Et)
        rhs = g_pose - block_matmul(EQ, w.unsqueeze(dim=2))
        dX = block_solve(S, rhs, ep=ep, lm=1e-4)

        # Back-substitute the pose update into the depth equations.
        dZ = Q * (w - block_matmul(Et, dX).squeeze(dim=-1))
        dX = dX.view(b, -1, 6)
    else:
        dZ = Q * w

    dZ = dZ.view(b, -1, 1, 1)

    px, py, disps = patches.unbind(dim=2)
    disps = disp_retr(disps, dZ, kx).clamp(min=1e-3, max=10.0)
    patches = torch.stack([px, py, disps], dim=2)

    if n > 0 and not structure_only:
        poses = pose_retr(poses, dX, fixedp + torch.arange(n))

    return poses, patches