# DrugFlow: src/data/so3_utils.py
import math
import torch
def _batch_trace(m):
    # trace over the trailing two dimensions of a batch of square matrices
    return torch.einsum('...ii', m)
def regularize(point, eps=1e-6):
"""
    Wraps the norm of the rotation vector into [0, pi].
    If the angle lies in (pi, 2 pi), the direction of the rotation axis is
    inverted so that the same rotation is represented with an angle of at most pi.
Args:
point, (n, 3)
Returns:
regularized point, (n, 3)
"""
    theta = torch.linalg.norm(point, dim=-1)
# angle in [0, 2pi)
theta_wrapped = theta % (2 * math.pi)
inv_mask = theta_wrapped > math.pi
    # map to (-pi, 0]; the sign flip below inverts the rotation axis
    theta_wrapped[inv_mask] = -1 * (2 * math.pi - theta_wrapped[inv_mask])
    # rescale; the clamp guards against division by zero at theta == 0
    theta = torch.clamp(theta, min=eps)
point = point * (theta_wrapped / theta).unsqueeze(-1)
assert not point.isnan().any()
return point
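
# Illustrative sketch (not part of the original module and never called on
# import): regularize() maps a rotation of 1.5*pi about +z to the equivalent
# rotation of 0.5*pi about -z.
def _example_regularize():
    vec = torch.tensor([[0., 0., 1.5 * math.pi]])
    expected = torch.tensor([[0., 0., -0.5 * math.pi]])
    assert torch.allclose(regularize(vec), expected, atol=1e-6)
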
def random_uniform(n_samples, device=None):
"""
    Samples random rotation vectors, following the geomstats implementation:
    https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html
    Args:
        n_samples: int
        device: optional torch device for the output
Returns:
rotation vectors, (n, 3)
"""
random_point = (torch.rand(n_samples, 3, device=device) * 2 - 1) * math.pi
random_point = regularize(random_point)
return random_point
def hat(rot_vec):
"""
    Maps an R^3 vector to a skew-symmetric matrix r (i.e. r \in R^{3x3} and r^T = -r).
    Since r v = rot_vec x v for all v \in R^3, this is the cross-product-matrix
    representation of rot_vec:
    rot_vec x v = hat(rot_vec) v
See also:
https://en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication
https://en.wikipedia.org/wiki/Hat_notation#Cross_product
Args:
rot_vec: (n, 3)
Returns:
skew-symmetric matrices (n, 3, 3)
"""
basis = torch.tensor([
[[0., 0., 0.], [0., 0., -1.], [0., 1., 0.]],
[[0., 0., 1.], [0., 0., 0.], [-1., 0., 0.]],
[[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]]
], device=rot_vec.device)
# basis = torch.tensor([
# [[0., 0., 0.], [0., 0., 1.], [0., -1., 0.]],
# [[0., 0., -1.], [0., 0., 0.], [1., 0., 0.]],
# [[0., 1., 0.], [-1., 0., 0.], [0., 0., 0.]]
# ], device=rot_vec.device)
return torch.einsum('...i,ijk->...jk', rot_vec, basis)
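
# Illustrative sketch (not part of the original module and never called on
# import): hat(a) acts on a vector v as the cross product a x v.
def _example_hat_cross_product(n=4):
    a, v = torch.randn(n, 3), torch.randn(n, 3)
    hat_a_v = torch.einsum('nij,nj->ni', hat(a), v)
    assert torch.allclose(hat_a_v, torch.cross(a, v, dim=-1), atol=1e-5)
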
def inv_hat(skew_mat):
"""
Inverse of hat operation
Args:
skew_mat: skew-symmetric matrices (n, 3, 3)
Returns:
rotation vectors, (n, 3)
"""
assert torch.allclose(-skew_mat, skew_mat.transpose(-2, -1), atol=1e-4), \
f"Input not skew-symmetric (err={(-skew_mat - skew_mat.transpose(-2, -1)).abs().max():.4g})"
# vec = torch.stack([
# skew_mat[:, 1, 2],
# skew_mat[:, 2, 1],
# skew_mat[:, 0, 1]
# ], dim=1)
vec = torch.stack([
skew_mat[:, 2, 1],
skew_mat[:, 0, 2],
skew_mat[:, 1, 0]
], dim=1)
return vec
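
# Illustrative sketch (not part of the original module and never called on
# import): inv_hat() undoes hat(), i.e. inv_hat(hat(a)) == a.
def _example_inv_hat_round_trip(n=4):
    a = torch.randn(n, 3)
    assert torch.allclose(inv_hat(hat(a)), a, atol=1e-6)
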
def matrix_from_rotation_vector(axis_angle, eps=1e-6):
"""
Args:
axis_angle: (n, 3)
Returns:
rotation matrices, (n, 3, 3)
"""
axis_angle = regularize(axis_angle)
angle = axis_angle.norm(dim=-1)
_norm = torch.clamp(angle, min=eps).unsqueeze(-1)
skew_mat = hat(axis_angle / _norm)
# https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula#Matrix_notation
_id = torch.eye(3, device=axis_angle.device).unsqueeze(0)
rot_mat = _id + \
torch.sin(angle)[:, None, None] * skew_mat + \
(1 - torch.cos(angle))[:, None, None] * torch.bmm(skew_mat, skew_mat)
return rot_mat
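
# Illustrative sketch (not part of the original module and never called on
# import): the Rodrigues matrix is orthogonal and leaves its own rotation
# axis fixed, i.e. R(v) @ v == v.
def _example_rodrigues_properties(n=4):
    vec = random_uniform(n)
    rot = matrix_from_rotation_vector(vec)
    eye = torch.eye(3).expand(n, 3, 3)
    assert torch.allclose(rot @ rot.transpose(-2, -1), eye, atol=1e-5)
    assert torch.allclose(torch.einsum('nij,nj->ni', rot, vec), vec, atol=1e-5)
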
class safe_acos(torch.autograd.Function):
"""
Implementation of arccos that avoids NaN in backward pass.
https://github.com/pytorch/pytorch/issues/8069#issuecomment-2041223872
"""
EPS = 1e-4
@classmethod
def d_acos_dx(cls, x):
x = torch.clamp(x, min=-1. + cls.EPS, max=1. - cls.EPS)
return -1.0 / (1 - x**2).sqrt()
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.acos()
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * safe_acos.d_acos_dx(input)
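
# Illustrative sketch (not part of the original module and never called on
# import): unlike torch.acos, safe_acos has a finite gradient at the
# boundary x = 1 thanks to the clamping in d_acos_dx.
def _example_safe_acos_gradient():
    x = torch.tensor([1.0], requires_grad=True)
    safe_acos.apply(x).sum().backward()
    assert torch.isfinite(x.grad).all()
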
def rotation_vector_from_matrix(rot_mat, approx=1e-4):
"""
Args:
rot_mat: (n, 3, 3)
approx: float, minimum angle below which an approximation will be used
for numerical stability
Returns:
rotation vector, (n, 3)
"""
# https://en.wikipedia.org/wiki/Rotation_matrix#Conversion_from_rotation_matrix_to_axis%E2%80%93angle
# https://en.wikipedia.org/wiki/Axis%E2%80%93angle_representation#Log_map_from_SO(3)_to_%F0%9D%94%B0%F0%9D%94%AC(3)
# determine axis
skew_mat = rot_mat - rot_mat.transpose(-2, -1)
# determine the angle
cos_angle = 0.5 * (_batch_trace(rot_mat) - 1)
# arccos is only defined between -1 and 1
assert torch.all(cos_angle.abs() <= 1 + 1e-6)
cos_angle = torch.clamp(cos_angle, min=-1., max=1.)
# abs_angle = torch.arccos(cos_angle)
abs_angle = safe_acos.apply(cos_angle)
# avoid numerical instability; use sin(x) \approx x for small x
close_to_0 = abs_angle < approx
_fac = torch.empty_like(abs_angle)
_fac[close_to_0] = 0.5
_fac[~close_to_0] = 0.5 * abs_angle[~close_to_0] / torch.sin(abs_angle[~close_to_0])
axis_angle = inv_hat(_fac[:, None, None] * skew_mat)
return regularize(axis_angle)
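
# Illustrative sketch (not part of the original module and never called on
# import): rotation_vector_from_matrix() inverts matrix_from_rotation_vector().
# Note that the extraction from rot_mat - rot_mat^T becomes ill-conditioned as
# the rotation angle approaches pi.
def _example_log_map_round_trip(n=4):
    vec = random_uniform(n)
    rot_mat = matrix_from_rotation_vector(vec)
    assert torch.allclose(rotation_vector_from_matrix(rot_mat), vec, atol=1e-4)
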
def get_jacobian(point, left=True, inverse=False, eps=1e-4):
    """
    Jacobian of translation on SO(3) in the rotation-vector parametrization.
    Mimics geomstats' `so3_vector.jacobian_translation(point, left)`, with
    `inverse=True` corresponding to the matrix inverse of that result:
    https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html
    The right Jacobian is defined as J_r(theta) = \partial exp([theta]_x) / \partial theta, see
    https://math.stackexchange.com/questions/301533/jacobian-involving-so3-exponential-map-logr-expm
    and: Chirikjian, Gregory S. Stochastic Models, Information Theory, and Lie
    Groups, Volume 2: Analytic Methods and Modern Applications. Springer
    Science & Business Media, 2011. (page 40)
    NOTE: the definitions of 'inverse' and 'left' in the book are the opposite
    of their meanings in Geomstats, whose functionality we are mimicking here.
    This explains the differences in the equations.
    Args:
        point: rotation vectors, (n, 3)
        left: bool, if True the result is transposed (Geomstats' 'left' convention)
        inverse: bool, return the inverse Jacobian
        eps: float, angle threshold below which limit values are used
    Returns:
        Jacobian matrices, (n, 3, 3)
    """
angle_squared = point.square().sum(-1)
angle = angle_squared.sqrt()
skew_mat = hat(point)
assert torch.all(angle <= math.pi)
close_to_0 = angle < eps
close_to_pi = (math.pi - angle) < eps
angle = angle[:, None, None]
angle_squared = angle_squared[:, None, None]
if inverse:
# _jacobian = torch.eye(3, device=point.device).unsqueeze(0) + \
# (1 - torch.cos(angle)) / angle_squared * skew_mat + \
# (angle - torch.sin(angle)) / angle ** 3 * (skew_mat @ skew_mat)
        _term1 = torch.empty_like(angle)
        _term1[close_to_0] = 0.5  # limit of (1 - cos) / theta^2 at zero
        _term1[~close_to_0] = (1 - torch.cos(angle[~close_to_0])) / angle_squared[~close_to_0]
        _term2 = torch.empty_like(angle)
        _term2[close_to_0] = 1 / 6  # limit of (theta - sin) / theta^3 at zero
        _term2[~close_to_0] = (angle[~close_to_0] - torch.sin(angle[~close_to_0])) / angle[~close_to_0] ** 3
jacobian = torch.eye(3, device=point.device).unsqueeze(0) + \
_term1 * skew_mat + _term2 * (skew_mat @ skew_mat)
# assert torch.allclose(jacobian, _jacobian, atol=1e-4)
else:
# _jacobian = torch.eye(3, device=point.device).unsqueeze(0) - 0.5 * skew_mat + \
# (1 / angle_squared - (1 + torch.cos(angle)) / (2 * angle * torch.sin(angle))) * (skew_mat @ skew_mat)
_term1 = torch.empty_like(angle)
_term1[close_to_0] = 1 / 12 # approximate with value at zero
_term1[close_to_pi] = 1 / math.pi**2 # approximate with value at pi
default = ~close_to_0 & ~close_to_pi
_term1[default] = 1 / angle_squared[default] - \
(1 + torch.cos(angle[default])) / (2 * angle[default] * torch.sin(angle[default]))
jacobian = torch.eye(3, device=point.device).unsqueeze(0) - \
0.5 * skew_mat + _term1 * (skew_mat @ skew_mat)
# assert torch.allclose(jacobian, _jacobian, atol=1e-4)
if left:
jacobian = jacobian.transpose(-2, -1)
return jacobian
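
# Illustrative sketch (not part of the original module and never called on
# import): for a fixed 'left' flag, the inverse=True and inverse=False
# Jacobians are matrix inverses of each other.
def _example_jacobian_inverse(n=4):
    point = random_uniform(n)
    prod = get_jacobian(point, inverse=True) @ get_jacobian(point, inverse=False)
    assert torch.allclose(prod, torch.eye(3).expand(n, 3, 3), atol=1e-4)
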
def compose_rotations(rot_vec_1, rot_vec_2):
    """Composes two rotations given as rotation vectors, (n, 3) each."""
    rot_mat_1 = matrix_from_rotation_vector(rot_vec_1)
    rot_mat_2 = matrix_from_rotation_vector(rot_vec_2)
    rot_mat_out = torch.bmm(rot_mat_1, rot_mat_2)
    return rotation_vector_from_matrix(rot_mat_out)
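
# Illustrative sketch (not part of the original module and never called on
# import): composing a rotation with its inverse (the negated rotation vector)
# yields the identity, i.e. the zero rotation vector.
def _example_compose_with_inverse(n=4):
    vec = random_uniform(n)
    assert torch.allclose(compose_rotations(vec, -vec), torch.zeros(n, 3), atol=1e-4)
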
def exp(tangent):
"""
Exponential map at identity.
Args:
tangent: vector on the tangent space, (n, 3)
Returns:
rotation vector on the manifold, (n, 3)
"""
# rotations are already represented by rotation vectors
exp_from_identity = regularize(tangent)
return exp_from_identity
def exp_not_from_identity(tangent_vec, base_point):
"""
Exponential map at base point.
Args:
tangent_vec: vector on the tangent plane, (n, 3)
base_point: base point on the manifold, (n, 3)
Returns:
new point on the manifold, (n, 3)
"""
tangent_vec = regularize(tangent_vec)
base_point = regularize(base_point)
# Lie algebra is the tangent space at the identity element of a Lie group
# -> to identity
jacobian = get_jacobian(base_point, left=True, inverse=True)
tangent_vec_at_id = torch.einsum("...ij,...j->...i", jacobian, tangent_vec)
# exponential map from identity
exp_from_identity = exp(tangent_vec_at_id)
# -> back to base point
return compose_rotations(base_point, exp_from_identity)
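
# Illustrative sketch (not part of the original module and never called on
# import): the exponential map of the zero tangent vector returns the base
# point itself (tolerance chosen to match the module's own tests below).
def _example_exp_of_zero_tangent(n=4):
    base = random_uniform(n)
    assert torch.allclose(exp_not_from_identity(torch.zeros(n, 3), base), base, atol=1e-3)
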
def log(rot_vec, as_skew=False):
"""
    Logarithm map at the identity.
    Args:
        rot_vec: point on the manifold, (n, 3)
        as_skew: bool, if True return skew-symmetric matrices (n, 3, 3) via hat()
Returns:
vector on the tangent space, (n, 3)
"""
# rotations are already represented by rotation vectors
# log_from_id = regularize(rot_vec)
log_from_id = rot_vec
if as_skew:
log_from_id = hat(log_from_id)
return log_from_id
def log_not_from_identity(point, base_point):
"""
Logarithm map of point from base point.
Args:
point: point on the manifold, (n, 3)
base_point: base point on the manifold, (n, 3)
Returns:
vector on the tangent plane, (n, 3)
"""
point = regularize(point)
base_point = regularize(base_point)
inv_base_point = -1 * base_point
point_near_id = compose_rotations(inv_base_point, point)
# logarithm map from identity
log_from_id = log(point_near_id)
jacobian = get_jacobian(base_point, inverse=False)
tangent_vec_at_id = torch.einsum("...ij,...j->...i", jacobian, log_from_id)
return tangent_vec_at_id
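
# Illustrative sketch (not part of the original module and never called on
# import): log_not_from_identity() recovers the tangent vector that
# exp_not_from_identity() was given (tolerance as in the module's own tests).
def _example_exp_log_round_trip(n=4):
    tangent, base = random_uniform(n), random_uniform(n)
    recovered = log_not_from_identity(exp_not_from_identity(tangent, base), base)
    assert torch.allclose(recovered, tangent, atol=1e-3)
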
if __name__ == "__main__":
import os
os.environ['GEOMSTATS_BACKEND'] = "pytorch"
import scipy.optimize # does not seem to be imported correctly when just loading geomstats
default_dtype = torch.get_default_dtype()
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
torch.set_default_dtype(default_dtype) # Geomstats changes default type when imported
so3_vector = SpecialOrthogonal(n=3, point_type="vector")
# decorator
    if int(torch.__version__.split('.')[0]) >= 2:
GEOMSTATS_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def geomstats_tensor_type(func):
def inner(*args, **kwargs):
with torch.device(GEOMSTATS_DEVICE):
out = func(*args, **kwargs)
return out
return inner
else:
GEOMSTATS_TENSOR_TYPE = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
# GEOMSTATS_TENSOR_TYPE = 'torch.cuda.DoubleTensor' if torch.cuda.is_available() else 'torch.DoubleTensor'
def geomstats_tensor_type(func):
def inner(*args, **kwargs):
                torch.set_default_tensor_type(GEOMSTATS_TENSOR_TYPE)
                out = func(*args, **kwargs)
                # TODO: restore the previous default tensor type instead of
                # hard-coding the CPU float type
                torch.set_default_tensor_type('torch.FloatTensor')
return out
return inner
@geomstats_tensor_type
def gs_matrix_from_rotation_vector(*args, **kwargs):
return so3_vector.matrix_from_rotation_vector(*args, **kwargs)
@geomstats_tensor_type
def gs_rotation_vector_from_matrix(*args, **kwargs):
return so3_vector.rotation_vector_from_matrix(*args, **kwargs)
@geomstats_tensor_type
def gs_exp_not_from_identity(*args, **kwargs):
return so3_vector.exp_not_from_identity(*args, **kwargs)
@geomstats_tensor_type
def gs_log_not_from_identity(*args, **kwargs):
# norm of the rotation vector will be between 0 and pi
return so3_vector.log_not_from_identity(*args, **kwargs)
@geomstats_tensor_type
def compose(*args, **kwargs):
return so3_vector.compose(*args, **kwargs)
@geomstats_tensor_type
def inverse(*args, **kwargs):
return so3_vector.inverse(*args, **kwargs)
@geomstats_tensor_type
def gs_random_uniform(*args, **kwargs):
return so3_vector.random_uniform(*args, **kwargs)
#############
# RUN TESTS #
#############
n = 16
device = 'cuda' if torch.cuda.is_available() else None
### regularize ###
    # sample angles beyond pi so that the wrapping branch of regularize() is exercised
    vec = (torch.rand(n, 3) * 4 - 2) * math.pi
axis_angle = regularize(vec)
    assert torch.all(torch.cross(vec, axis_angle, dim=-1).norm(dim=-1) < 1e-5), "not all vectors collinear"
assert torch.all(axis_angle.norm(dim=-1) < math.pi) & torch.all(axis_angle.norm(dim=-1) >= 0), "norm not between 0 and pi"
### matrix_from_rotation_vector ###
rot_vec = random_uniform(16, device=device)
assert torch.allclose(matrix_from_rotation_vector(rot_vec),
gs_matrix_from_rotation_vector(rot_vec), atol=1e-06)
### rotation_vector_from_matrix ###
rot_vec = random_uniform(16, device=device)
rot_mat = matrix_from_rotation_vector(rot_vec)
assert torch.allclose(rotation_vector_from_matrix(rot_mat),
gs_rotation_vector_from_matrix(rot_mat), atol=1e-05)
### exp_not_from_identity ###
tangent_vec = random_uniform(16, device=device)
base_pt = random_uniform(16, device=device)
my_val = exp_not_from_identity(tangent_vec, base_pt)
gs_val = gs_exp_not_from_identity(tangent_vec, base_pt)
assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()
### log_not_from_identity ###
pt = random_uniform(16, device=device)
base_pt = random_uniform(16, device=device)
my_val = log_not_from_identity(pt, base_pt)
gs_val = gs_log_not_from_identity(pt, base_pt)
assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()
print("All tests successful!")