import torch
import lietorch

from lietorch import SO3, RxSO3, SE3, Sim3
from gradcheck import gradcheck, get_analytical_jacobian
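
# Note: gradcheck is imported from a local module rather than torch.autograd;
# judging from its use below, it returns the (analytical, numerical) Jacobian
# pair instead of a bool, and get_analytical_jacobian exposes the analytical
# Jacobian directly.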

### forward tests ###

def make_homogeneous(p):
    return torch.cat([p, torch.ones_like(p[...,:1])], dim=-1)

def matv(A, b):
    return torch.matmul(A, b[...,None])[..., 0]
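
# Exp and Log are mutual inverses only in a neighborhood of the identity (for
# SO3, rotation angles below pi), so the tangent vectors below are drawn small
# (scaled by 0.2) to stay inside the injectivity radius.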
def test_exp_log(Group, device='cuda'):
    """ check Log(Exp(x)) == x """
    a = .2*torch.randn(2,3,4,5,6,7,Group.manifold_dim, device=device).double()
    b = Group.exp(a).log()

    assert torch.allclose(a,b,atol=1e-8), "should be identity"
    print("\t-", Group, "Passed exp-log test")
def test_inv(Group, device='cuda'):
    """ check Log(X * X^{-1}) == 0 """
    X = Group.exp(.1*torch.randn(2,3,4,5,Group.manifold_dim, device=device).double())
    a = (X * X.inv()).log()

    assert torch.allclose(a, torch.zeros_like(a), atol=1e-8), "should be 0"
    print("\t-", Group, "Passed inv test")
def test_adj(Group, device='cuda'):
    """ check X * Exp(a) == Exp(Adj(X,a)) * X """
    X = Group.exp(torch.randn(2,3,4,5, Group.manifold_dim, device=device).double())
    a = torch.randn(2,3,4,5, Group.manifold_dim, device=device).double()

    b = X.adj(a)
    Y1 = X * Group.exp(a)
    Y2 = Group.exp(b) * X

    c = (Y1 * Y2.inv()).log()
    assert torch.allclose(c, torch.zeros_like(c), atol=1e-8), "should be 0"
    print("\t-", Group, "Passed adj test")
def test_act(Group, device='cuda'):
    X = Group.exp(torch.randn(1, Group.manifold_dim, device=device).double())
    p = torch.randn(1,3,device=device).double()

    p1 = X.act(p)
    p2 = matv(X.matrix(), make_homogeneous(p))

    assert torch.allclose(p1, p2[...,:3], atol=1e-8), "should be equal"
    print("\t-", Group, "Passed act test")


### backward tests ###
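
# The backward tests share one pattern: the group element is perturbed by
# Exp(a) and the Jacobian dfn/da is evaluated at a = 0, so gradcheck compares
# analytical and numerical Jacobians on the tangent space.  For the exp-log
# round trip this Jacobian should be the D x D identity, which is what
# test_exp_log_grad asserts against torch.eye(D).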
def test_exp_log_grad(Group, device='cuda', tol=1e-8):

    D = Group.manifold_dim

    def fn(a):
        return Group.exp(a).log()

    a = torch.zeros(1, Group.manifold_dim, requires_grad=True, device=device).double()
    analytical, reentrant, correct_grad_sizes, correct_grad_types = \
        get_analytical_jacobian((a,), fn(a))

    assert torch.allclose(analytical[0], torch.eye(D, device=device).double(), atol=tol)

    a = .2 * torch.randn(1, Group.manifold_dim, requires_grad=True, device=device).double()
    analytical, reentrant, correct_grad_sizes, correct_grad_types = \
        get_analytical_jacobian((a,), fn(a))

    assert torch.allclose(analytical[0], torch.eye(D, device=device).double(), atol=tol)

    print("\t-", Group, "Passed eye-grad test")
def test_inv_log_grad(Group, device='cuda', tol=1e-8):

    D = Group.manifold_dim
    X = Group.exp(.2*torch.randn(1,D,device=device).double())

    def fn(a):
        return (Group.exp(a) * X).inv().log()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    # assert torch.allclose(analytical[0], numerical[0], atol=tol)
    if not torch.allclose(analytical[0], numerical[0], atol=tol):
        print(analytical[0])
        print(numerical[0])

    print("\t-", Group, "Passed inv-grad test")
def test_adj_grad(Group, device='cuda'):
    D = Group.manifold_dim
    X = Group.exp(.5*torch.randn(1,Group.manifold_dim, device=device).double())

    def fn(a, b):
        return (Group.exp(a) * X).adj(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed adj-grad test")

def test_adjT_grad(Group, device='cuda'):
    D = Group.manifold_dim
    X = Group.exp(.5*torch.randn(1,Group.manifold_dim, device=device).double())

    def fn(a, b):
        return (Group.exp(a) * X).adjT(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed adjT-grad test")

def test_act_grad(Group, device='cuda'):
    D = Group.manifold_dim
    X = Group.exp(5*torch.randn(1,D, device=device).double())

    def fn(a, b):
        return (X*Group.exp(a)).act(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, 3, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed act-grad test")

def test_matrix_grad(Group, device='cuda'):
    D = Group.manifold_dim
    X = Group.exp(torch.randn(1, D, device=device).double())

    def fn(a):
        return (Group.exp(a) * X).matrix()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-6)

    print("\t-", Group, "Passed matrix-grad test")
def extract_translation_grad(Group, device='cuda'):
    """ prototype function """

    D = Group.manifold_dim
    X = Group.exp(5*torch.randn(1,D, device=device).double())

    def fn(a):
        return (Group.exp(a)*X).translation()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    print("\t-", Group, "Passed translation grad test")

def test_vec_grad(Group, device='cuda', tol=1e-6):

    D = Group.manifold_dim
    X = Group.exp(5*torch.randn(1,D, device=device).double())

    def fn(a):
        return (Group.exp(a)*X).vec()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    print("\t-", Group, "Passed tovec grad test")
def test_fromvec_grad(Group, device='cuda', tol=1e-6):

    def fn(a):
        if Group == SO3:
            a = a / a.norm(dim=-1, keepdim=True)

        elif Group == RxSO3:
            q, s = a.split([4, 1], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([q, s.exp()], dim=-1)

        elif Group == SE3:
            t, q = a.split([3, 4], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([t, q], dim=-1)

        elif Group == Sim3:
            t, q, s = a.split([3, 4, 1], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([t, q, s.exp()], dim=-1)

        return Group.InitFromVec(a).vec()

    D = Group.embedded_dim
    a = torch.randn(1, 2, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=tol)

    print("\t-", Group, "Passed fromvec grad test")
def scale(device='cuda'):

    def fn(a, s):
        X = SE3.exp(a)
        X.scale(s)
        return X.log()

    s = torch.rand(1, requires_grad=True, device=device).double()
    a = torch.randn(1, 6, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, s], eps=1e-3)
    print(analytical[1])
    print(numerical[1])

    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", "Passed se3-to-sim3 test")
if __name__ == '__main__':

    print("Testing lietorch forward pass (CPU) ...")
    for Group in [SO3, RxSO3, SE3, Sim3]:
        test_exp_log(Group, device='cpu')
        test_inv(Group, device='cpu')
        test_adj(Group, device='cpu')
        test_act(Group, device='cpu')

    print("Testing lietorch backward pass (CPU) ...")
    for Group in [SO3, RxSO3, SE3, Sim3]:
        if Group == Sim3:
            tol = 1e-3
        else:
            tol = 1e-8

        test_exp_log_grad(Group, device='cpu', tol=tol)
        test_inv_log_grad(Group, device='cpu', tol=tol)
        test_adj_grad(Group, device='cpu')
        test_adjT_grad(Group, device='cpu')
        test_act_grad(Group, device='cpu')
        test_matrix_grad(Group, device='cpu')
        extract_translation_grad(Group, device='cpu')
        test_vec_grad(Group, device='cpu')
        test_fromvec_grad(Group, device='cpu')

    print("Testing lietorch forward pass (GPU) ...")
    for Group in [SO3, RxSO3, SE3, Sim3]:
        test_exp_log(Group, device='cuda')
        test_inv(Group, device='cuda')
        test_adj(Group, device='cuda')
        test_act(Group, device='cuda')

    print("Testing lietorch backward pass (GPU) ...")
    for Group in [SO3, RxSO3, SE3, Sim3]:
        if Group == Sim3:
            tol = 1e-3
        else:
            tol = 1e-8

        test_exp_log_grad(Group, device='cuda', tol=tol)
        test_inv_log_grad(Group, device='cuda', tol=tol)
        test_adj_grad(Group, device='cuda')
        test_adjT_grad(Group, device='cuda')
        test_act_grad(Group, device='cuda')
        test_matrix_grad(Group, device='cuda')
        extract_translation_grad(Group, device='cuda')
        test_vec_grad(Group, device='cuda')
        test_fromvec_grad(Group, device='cuda')