Spaces:
Running
on
T4
Running
on
T4
import os | |
import numpy as np | |
import torch | |
from scipy.spatial.transform import Rotation | |
# Noise-level grid bounds: the precomputed tables have one row per eps value,
# N_EPS values log-spaced between MIN_EPS and MAX_EPS.
MIN_EPS, MAX_EPS, N_EPS = 0.01, 2, 1000
# Number of rotation-angle grid points on (0, pi].
X_N = 2000
# Preprocessing for the SO(3) sampling and score computations: truncated
# infinite series are evaluated on the (eps, omega) grid and the resulting
# tables are cached to disk as .npy files in the current working directory.
# The expensive precomputation therefore only runs the first time the code is
# run from a given working directory; later imports just load the cached files.
omegas = np.linspace(0, np.pi, X_N + 1)[1:]
def _compose(r1, r2):  # R1 @ R2 but for rotation vectors
    """Compose two rotations given as rotation vectors (axis * angle).

    Returns the rotation vector of ``R1 @ R2``: ``r2`` is applied first, then
    ``r1``. Accepts single vectors of shape (3,) or equal-length stacks (N, 3).
    """
    # Rotation.__mul__ composes directly (p * q has matrix p.as_matrix() @
    # q.as_matrix()), avoiding the round trip through explicit 3x3 matrices.
    return (Rotation.from_rotvec(r1) * Rotation.from_rotvec(r2)).as_rotvec()
def _expansion(omega, eps, L=2000):  # the summation term only
    """Evaluate the truncated angular series at rotation angle(s) ``omega``.

    Computes::

        sum_{l=0}^{L-1} (2l + 1) * exp(-l (l + 1) eps**2)
                        * sin((l + 1/2) omega) / sin(omega / 2)

    Args:
        omega: rotation angle(s) in radians, scalar or NumPy array (evaluated
            elementwise). Must be nonzero, because the series divides by
            sin(omega / 2).
        eps: noise level.
        L: number of series terms kept.

    Returns:
        The partial sum, shaped like ``omega`` (the int 0 when ``L == 0``).
    """
    # sin(omega / 2) does not depend on l: compute it once, not L times.
    denom = np.sin(omega / 2)
    p = 0
    for l in range(L):
        p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * np.sin(omega * (l + 1 / 2)) / denom
    return p
def _density(expansion, omega, marginal=True):  # if marginal, density over [0, pi], else over SO(3)
    """Turn a series value into a probability density.

    Args:
        expansion: series value(s), e.g. from ``_expansion``.
        omega: rotation angle(s) matching ``expansion``; only used when
            ``marginal`` is true.
        marginal: if true, return the marginal density of the rotation angle
            on [0, pi], i.e. expansion * (1 - cos(omega)) / pi; otherwise
            return the density over SO(3), expansion / (8 * pi**2).
    """
    if not marginal:
        # the constant factor doesn't affect any actual calculations though
        return expansion / 8 / np.pi ** 2
    return expansion * (1 - np.cos(omega)) / np.pi
def _score(exp, omega, eps, L=2000):  # score of density over SO(3)
    """Return the omega-derivative of the truncated series, divided by ``exp``.

    With ``exp`` equal to ``_expansion(omega, eps, L)`` (as the table
    precomputation passes it), the result is d log f / d omega of the series f,
    which is the score of the SO(3) density along the rotation angle, since the
    constant normalisation drops out of the log-derivative.

    Args:
        exp: series value(s) to divide by; shaped like (or broadcastable to)
            ``omega``.
        omega: nonzero rotation angle(s) in radians, scalar or NumPy array.
        eps: noise level.
        L: number of series terms kept.
    """
    # The denominator sin(omega / 2) and its derivative do not depend on l,
    # so evaluate them once instead of on every iteration.
    lo = np.sin(omega / 2)
    dlo = 1 / 2 * np.cos(omega / 2)
    lo_sq = lo ** 2
    dSigma = 0
    for l in range(L):
        half = l + 1 / 2
        hi = np.sin(omega * half)
        dhi = half * np.cos(omega * half)
        # Quotient rule: d/d(omega) [hi / lo] = (lo * dhi - hi * dlo) / lo**2.
        dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * (lo * dhi - hi * dlo) / lo_sq
    return dSigma / exp
# Lookup tables for sampling and scores, cached as .npy files in the current
# working directory: loaded if present, otherwise computed (slow: Python loops
# over N_EPS noise levels x L series terms) and then saved.
_SO3_CACHE_FILES = (
    '.so3_omegas_array2.npy',
    '.so3_cdf_vals2.npy',
    '.so3_score_norms2.npy',
    '.so3_exp_score_norms2.npy',
)
# Require every cache file, not just the first, so a partially written cache
# (e.g. an interrupted first run) is rebuilt instead of making np.load raise
# FileNotFoundError on every subsequent import.
if all(os.path.exists(_f) for _f in _SO3_CACHE_FILES):
    _omegas_array = np.load('.so3_omegas_array2.npy')
    _cdf_vals = np.load('.so3_cdf_vals2.npy')
    _score_norms = np.load('.so3_score_norms2.npy')
    _exp_score_norms = np.load('.so3_exp_score_norms2.npy')
else:
    # N_EPS noise levels, log-spaced between MIN_EPS and MAX_EPS.
    _eps_array = 10 ** np.linspace(np.log10(MIN_EPS), np.log10(MAX_EPS), N_EPS)
    # X_N angles on (0, pi]; omega = 0 is dropped because the series divides
    # by sin(omega / 2).
    _omegas_array = np.linspace(0, np.pi, X_N + 1)[1:]
    # Tables below have one row per eps and one column per omega.
    _exp_vals = np.asarray([_expansion(_omegas_array, eps) for eps in _eps_array])
    _pdf_vals = np.asarray([_density(_exp, _omegas_array, marginal=True) for _exp in _exp_vals])
    # Rectangle-rule cumulative integral of the marginal pdf, step pi / X_N.
    _cdf_vals = np.asarray([_pdf.cumsum() / X_N * np.pi for _pdf in _pdf_vals])
    _score_norms = np.asarray([_score(_exp, _omegas_array, eps) for _exp, eps in zip(_exp_vals, _eps_array)])
    # Square root of the pdf-weighted mean squared score, per eps.
    # NOTE(review): the additional division by np.pi is kept as-is; confirm it
    # is the normalisation downstream callers expect.
    _exp_score_norms = np.sqrt(np.sum(_score_norms**2 * _pdf_vals, axis=1) / np.sum(_pdf_vals, axis=1) / np.pi)
    np.save('.so3_omegas_array2.npy', _omegas_array)
    np.save('.so3_cdf_vals2.npy', _cdf_vals)
    np.save('.so3_score_norms2.npy', _score_norms)
    np.save('.so3_exp_score_norms2.npy', _exp_score_norms)
def sample(eps):
    """Draw a rotation angle for noise level ``eps`` by inverse-CDF sampling.

    ``eps`` (a positive scalar) is mapped to the nearest row of the precomputed
    eps grid (N_EPS values log-spaced between MIN_EPS and MAX_EPS; values
    outside that range are clamped to the end rows). A uniform draw is then
    pushed through that row's tabulated CDF by linear interpolation, so the
    result lies within the tabulated angle range (0, pi].
    """
    # The grid has N_EPS points, i.e. N_EPS - 1 intervals, so the fractional
    # log-position must be scaled by N_EPS - 1; scaling by N_EPS would select
    # a neighbouring row for most eps.
    eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * (N_EPS - 1)
    eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1)
    x = np.random.rand()
    return np.interp(x, _cdf_vals[eps_idx], _omegas_array)
def sample_vec(eps):
    """Sample a rotation vector (axis * angle) for noise level ``eps``.

    The axis is uniform on the unit sphere (a normalised standard-normal
    3-vector); the angle comes from ``sample(eps)``.
    """
    axis = np.random.randn(3)
    axis /= np.linalg.norm(axis)
    angle = sample(eps)
    return axis * angle
def score_vec(eps, vec):
    """Return the score vector of the SO(3) density at rotation vector ``vec``.

    The magnitude is the tabulated d log f / d omega at angle ``|vec|`` for the
    grid row nearest to ``eps`` (clamped to the grid ends). The direction is
    along ``vec``.

    Args:
        eps: positive scalar noise level.
        vec: rotation vector (axis * angle), shape (3,).
    """
    # The eps grid has N_EPS points (N_EPS - 1 intervals): scale by N_EPS - 1.
    eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * (N_EPS - 1)
    eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1)
    om = np.linalg.norm(vec)
    if om == 0:
        # Identity rotation: the series is even in omega, so its
        # log-derivative vanishes here; avoid the 0/0 that would yield NaN.
        return np.zeros_like(vec, dtype=float)
    return np.interp(om, _omegas_array, _score_norms[eps_idx]) * vec / om
def score_norm(eps):
    """Look up the tabulated expected score norm for each noise level in ``eps``.

    Each eps is mapped to the nearest row of the precomputed eps grid (clamped
    to the grid ends).

    Args:
        eps: torch tensor of positive noise levels, on any device and possibly
            requiring grad. NumPy arrays and Python scalars are also accepted.

    Returns:
        A float32 CPU torch tensor with the same shape as ``eps``.
    """
    if torch.is_tensor(eps):
        # Tensor.numpy() refuses CUDA tensors and tensors that require grad.
        eps = eps.detach().cpu().numpy()
    eps = np.asarray(eps)
    # The eps grid has N_EPS points (N_EPS - 1 intervals): scale by N_EPS - 1.
    eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * (N_EPS - 1)
    eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1)
    # A 0-d input indexes out a NumPy scalar, which torch.from_numpy rejects;
    # np.asarray turns it back into a 0-d array.
    return torch.from_numpy(np.asarray(_exp_score_norms[eps_idx])).float()