Spaces:
Runtime error
Runtime error
File size: 4,660 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .geometry import quaternion_to_rotation_matrix
def log_rotation(R):
    """Matrix logarithm of (batched) rotation matrices.

    Uses cos(theta) = (tr(R) - 1) / 2 and returns
    theta / (2 sin theta) * (R - R^T).

    Args:
        R: Tensor of shape (..., 3, 3); presumably proper rotation matrices.

    Returns:
        Tensor of shape (..., 3, 3) holding the skew-symmetric logarithm.
    """
    trace = torch.diagonal(R, dim1=-2, dim2=-1).sum(-1)
    if torch.is_grad_enabled():
        # d/dx acos(x) = -1/sqrt(1 - x^2) diverges at x = -1 (theta = pi), so keep
        # cos(theta) at or above -0.999 whenever gradients may be required.
        min_cos = -0.999
    else:
        min_cos = -1.0
    # Clamp from above too: round-off can push (trace - 1) / 2 slightly past 1 for
    # near-identity matrices, which would make both sqrt and acos below return NaN.
    cos_theta = ((trace - 1) / 2).clamp(min=min_cos, max=1.0)
    sin_theta = torch.sqrt(1 - cos_theta**2)
    theta = torch.acos(cos_theta)
    # theta / (2 sin theta) -> 1/2 as theta -> 0; the epsilons yield exactly 1/2 at theta = 0.
    coef = ((theta + 1e-8) / (2 * sin_theta + 2e-8))[..., None, None]
    return coef * (R - R.transpose(-1, -2))
def skewsym_to_so3vec(S):
    """Read the 3-vector out of a (..., 3, 3) skew-symmetric matrix.

    Returns a (..., 3) tensor [S[1, 2], S[2, 0], S[0, 1]], which is the inverse
    of the layout produced by ``so3vec_to_skewsym`` in this module.
    """
    components = (S[..., 1, 2], S[..., 2, 0], S[..., 0, 1])
    return torch.stack(components, dim=-1)
def so3vec_to_skewsym(w):
    """Build the (..., 3, 3) skew-symmetric matrix for a (..., 3) vector.

    With w = (x, y, z) the result is
        [[ 0,  z, -y],
         [-z,  0,  x],
         [ y, -x,  0]]
    """
    x, y, z = w.unbind(-1)
    zero = torch.zeros_like(x)
    rows = (
        torch.stack((zero, z, -y), dim=-1),
        torch.stack((-z, zero, x), dim=-1),
        torch.stack((y, -x, zero), dim=-1),
    )
    # Stacking on dim -2 makes each tuple entry a matrix row.
    return torch.stack(rows, dim=-2)
def exp_skewsym(S):
    """Matrix exponential of a (..., 3, 3) skew-symmetric matrix (Rodrigues formula).

    exp(S) = I + sin(a)/a * S + (1 - cos(a))/a^2 * S^2, where a is the norm of
    the 3-vector stored in S. The epsilons make the coefficients evaluate to
    1 and 0.5 (their small-angle limits) when a == 0.
    """
    angle = torch.linalg.norm(skewsym_to_so3vec(S), dim=-1)
    batch_ones = [1] * (S.dim() - 2)
    eye = torch.eye(3).to(S).view(batch_ones + [3, 3])
    sin_coef = (torch.sin(angle) + 1e-8) / (angle + 1e-8)
    cos_coef = (1 - torch.cos(angle) + 1e-8) / (angle**2 + 2e-8)  # lim_{a->0} (1-cos a)/a^2 = 0.5
    S_squared = torch.matmul(S, S)
    return eye + sin_coef[..., None, None] * S + cos_coef[..., None, None] * S_squared
def so3vec_to_rotation(w):
    """Map (..., 3) rotation vectors to (..., 3, 3) rotation matrices via exp of the skew matrix."""
    skew = so3vec_to_skewsym(w)
    return exp_skewsym(skew)
def rotation_to_so3vec(R):
    """Map (..., 3, 3) rotation matrices to (..., 3) rotation vectors.

    Takes the matrix log with ``log_rotation`` and reads the vector out of
    the resulting skew-symmetric matrix.
    """
    return skewsym_to_so3vec(log_rotation(R))
def random_uniform_so3(size, device='cpu'):
    """Sample random rotations, returned as (*size, 3) rotation vectors.

    Draws standard-normal 4-vectors and normalizes them to unit quaternions.
    Presumably quaternion_to_rotation_matrix is the standard unit-quaternion map;
    TODO confirm, since uniformity on SO(3) depends on it.
    """
    quat_shape = [*size, 4]
    quat = torch.randn(quat_shape, device=device)
    quat = F.normalize(quat, dim=-1)  # (..., 4)
    rotation = quaternion_to_rotation_matrix(quat)
    return rotation_to_so3vec(rotation)
class ApproxAngularDistribution(nn.Module):
    """Sampler for rotation angles, with one angular distribution per stddev.

    For every entry of ``stddevs`` the series density evaluated by ``_pdf``
    (which has the form of the IGSO(3) angle density) is tabulated on
    ``num_bins`` grid points over [0, pi] and sampled as a histogram.
    Entries with ``stddev <= std_threshold`` are instead sampled from a folded
    Gaussian approximation (see ``sample``).

    Args:
        stddevs: Sequence of standard deviations (floats).
        std_threshold: Stddevs at or below this value use the Gaussian approximation.
        num_bins: Number of grid points over [0, pi] for each histogram.
        num_iters: Number of series terms used when evaluating the density.
    """

    def __init__(self, stddevs, std_threshold=0.1, num_bins=8192, num_iters=1024):
        super().__init__()
        self.std_threshold = std_threshold
        self.num_bins = num_bins
        self.num_iters = num_iters
        # torch.tensor replaces the legacy torch.FloatTensor constructor (same float32 CPU result).
        self.register_buffer('stddevs', torch.tensor(stddevs, dtype=torch.float32))
        self.register_buffer('approx_flag', self.stddevs <= std_threshold)
        self._precompute_histograms()

    @staticmethod
    def _pdf(x, e, L):
        """Evaluate the truncated density series at angles ``x``.

        f(x) = (1 - cos x) / pi * sum_{l=0}^{L-1} (2l+1) exp(-l(l+1) e^2)
                                   * sin((l + 1/2) x) / sin(x / 2)

        Args:
            x: (N, ) angles.
            e: Float, standard deviation parameter.
            L: Integer, number of series terms.

        Returns:
            (N, ) density values.
        """
        x = x[:, None]  # (N, 1)
        c = (1 - torch.cos(x)) / math.pi  # (N, 1)
        # Build the term index on x's device so the helper also works for non-CPU inputs.
        l = torch.arange(0, L, device=x.device)[None, :]  # (1, L)
        a = (2 * l + 1) * torch.exp(-l * (l + 1) * (e**2))  # (1, L)
        # The 1e-6 offsets avoid 0/0 at x = 0.
        b = (torch.sin((l + 0.5) * x) + 1e-6) / (torch.sin(x / 2) + 1e-6)  # (N, L)
        return (c * a * b).sum(dim=1)

    def _precompute_histograms(self):
        """Tabulate the density for every stddev on a shared grid over [0, pi].

        Registers buffers ``X`` (grid) and ``Y`` (density), both of shape
        (n_stddevs, num_bins). NaNs are zeroed and negative values clamped to 0.
        """
        # The grid does not depend on the stddev, so it is built once.
        x = torch.linspace(0, math.pi, self.num_bins)  # (num_bins,)
        ys = []
        for std in self.stddevs.tolist():
            y = self._pdf(x, std, self.num_iters)  # (num_bins,)
            ys.append(torch.nan_to_num(y).clamp_min(0))
        if ys:
            Y = torch.stack(ys, dim=0)
        else:
            # torch.stack rejects an empty list; keep a well-formed (0, num_bins) table.
            Y = x.new_zeros((0, self.num_bins))
        X = x[None, :].repeat(len(ys), 1)
        self.register_buffer('X', X)  # (n_stddevs, num_bins)
        self.register_buffer('Y', Y)  # (n_stddevs, num_bins)

    def sample(self, std_idx):
        """Draw one angle per entry of ``std_idx``.

        Args:
            std_idx: Integer tensor (any shape) of indices into ``stddevs``.
        Returns:
            samples: Angular samples [0, PI), same shape as ``std_idx``.
        """
        size = std_idx.size()
        std_idx = std_idx.flatten()  # (N,)
        if std_idx.numel() == 0:
            return self.X.new_zeros(size)

        # Histogram sampling: pick a bin weighted by the density at its left edge,
        # then sample uniformly inside it. The last grid point starts no bin.
        prob = self.Y[std_idx]  # (N, n_bins)
        bin_idx = torch.multinomial(prob[:, :-1], num_samples=1).squeeze(-1)  # (N,)
        bin_start = self.X[std_idx, bin_idx]  # (N,)
        bin_width = self.X[std_idx, bin_idx + 1] - bin_start
        samples_hist = bin_start + torch.rand_like(bin_start) * bin_width  # (N,)

        # Gaussian approximation: N(2*std, std^2), folded by abs() and wrapped into [0, pi).
        std_gaussian = self.stddevs[std_idx]
        mean_gaussian = std_gaussian * 2
        samples_gaussian = mean_gaussian + torch.randn_like(mean_gaussian) * std_gaussian
        samples_gaussian = samples_gaussian.abs() % math.pi

        # Per entry, use the Gaussian sample when its stddev is at or below the threshold.
        gaussian_flag = self.approx_flag[std_idx]
        samples = torch.where(gaussian_flag, samples_gaussian, samples_hist)
        return samples.reshape(size)
def random_normal_so3(std_idx, angular_distrib, device='cpu'):
    """Sample rotation vectors with random axes and angles from ``angular_distrib``.

    The axis is a normalized standard-normal 3-vector (uniform on the unit sphere).
    The angle is drawn by ``angular_distrib.sample(std_idx)``. Returns a tensor of
    shape (*std_idx.shape, 3) whose norm is the sampled angle.
    """
    axis = torch.randn([*std_idx.shape, 3], device=device)
    axis = F.normalize(axis, dim=-1)
    angle = angular_distrib.sample(std_idx)
    return axis * angle.unsqueeze(-1)
|