diffdock / utils /so3.py
gcorso's picture
first commit
4a3f787
raw
history blame
3.59 kB
import os
import numpy as np
import torch
from scipy.spatial.transform import Rotation
# Noise-scale (eps) grid: N_EPS values spanning [MIN_EPS, MAX_EPS]; the
# precomputation below spaces them evenly in log10.
MIN_EPS = 0.01
MAX_EPS = 2
N_EPS = 1000

# Number of rotation-angle grid points on (0, pi].
X_N = 2000
"""
Preprocessing for the SO(3) sampling and score computations: the truncated infinite series are computed once and
cached to disk as .npy files in the current working directory, so the precomputation only runs the first time the
repository is run from that directory.
"""
# Rotation-angle grid: X_N evenly spaced angles on (0, pi]. linspace produces
# X_N + 1 points including 0, and the leading 0 is sliced off.
# NOTE(review): not referenced elsewhere in the visible code (the cache section
# builds its own `_omegas_array`) -- presumably kept for external importers.
omegas = np.linspace(0, np.pi, X_N + 1)[1:]
def _compose(r1, r2):
    """Compose two rotations given as rotation vectors.

    Both inputs are converted to rotation matrices, multiplied as
    ``R1 @ R2`` and the product is converted back to a rotation vector.
    (The original comment calls the inputs "Euler vecs"; they are handled
    with ``Rotation.from_rotvec`` / ``as_rotvec``.)
    """
    first = Rotation.from_rotvec(r1).as_matrix()
    second = Rotation.from_rotvec(r2).as_matrix()
    combined = Rotation.from_matrix(first @ second)
    return combined.as_rotvec()
def _expansion(omega, eps, L=2000):  # the summation term only
    """Evaluate the truncated series, summed over l = 0 .. L - 1.

    Each term is
    ``(2l + 1) * exp(-l (l + 1) eps^2) * sin((l + 1/2) omega) / sin(omega / 2)``.

    ``omega`` may be a scalar or a NumPy array (the sum is element-wise);
    ``omega == 0`` divides by zero. Returns the integer ``0`` when ``L == 0``.
    """
    # The denominator does not depend on l, so evaluate it once.
    half_angle_sin = np.sin(omega / 2)
    total = 0
    for degree in range(L):
        weight = (2 * degree + 1) * np.exp(-degree * (degree + 1) * eps**2)
        total += weight * np.sin(omega * (degree + 1 / 2)) / half_angle_sin
    return total
def _density(expansion, omega, marginal=True):
    """Turn a series value into a density.

    With ``marginal`` truthy the result is the density over the rotation
    angle on [0, pi]: ``expansion * (1 - cos(omega)) / pi``. Otherwise it is
    the density over SO(3): ``expansion / (8 pi^2)``; ``omega`` is unused in
    that case.
    """
    if not marginal:
        # The constant factor doesn't affect any actual calculations though.
        return expansion / 8 / np.pi ** 2
    return expansion * (1 - np.cos(omega)) / np.pi
def _score(exp, omega, eps, L=2000):  # score of density over SO(3)
    """Return d/d(omega) of the truncated series divided by the series value.

    ``exp`` is the series value at the same ``omega`` and ``eps``. Each term
    of the series is differentiated with the quotient rule applied to
    ``sin((l + 1/2) omega) / sin(omega / 2)``, and the derivatives are summed
    over l = 0 .. L - 1.
    """
    # Denominator sin(omega / 2), its derivative and its square are the same
    # for every l.
    lo = np.sin(omega / 2)
    dlo = 1 / 2 * np.cos(omega / 2)
    lo_squared = lo ** 2

    series_derivative = 0
    for degree in range(L):
        frequency = degree + 1 / 2
        hi = np.sin(omega * frequency)
        dhi = frequency * np.cos(omega * frequency)
        weight = (2 * degree + 1) * np.exp(-degree * (degree + 1) * eps**2)
        series_derivative += weight * (lo * dhi - hi * dlo) / lo_squared
    return series_derivative / exp
# Lookup tables are cached as .npy files, resolved relative to the current
# working directory.
_SO3_CACHE_PATHS = (
    '.so3_omegas_array2.npy',
    '.so3_cdf_vals2.npy',
    '.so3_score_norms2.npy',
    '.so3_exp_score_norms2.npy',
)

# Only load the cache when every file is present. Checking just the first file
# would crash in np.load if an earlier run was interrupted between the saves
# below (or if one of the files was removed).
if all(os.path.exists(_path) for _path in _SO3_CACHE_PATHS):
    _omegas_array = np.load('.so3_omegas_array2.npy')
    _cdf_vals = np.load('.so3_cdf_vals2.npy')
    _score_norms = np.load('.so3_score_norms2.npy')
    _exp_score_norms = np.load('.so3_exp_score_norms2.npy')
else:
    # N_EPS noise scales, evenly spaced in log10 between MIN_EPS and MAX_EPS.
    _eps_array = 10 ** np.linspace(np.log10(MIN_EPS), np.log10(MAX_EPS), N_EPS)
    # X_N rotation angles on (0, pi]; 0 is dropped.
    _omegas_array = np.linspace(0, np.pi, X_N + 1)[1:]

    # One row per eps, one column per omega.
    _exp_vals = np.asarray([_expansion(_omegas_array, eps) for eps in _eps_array])
    _pdf_vals = np.asarray([_density(_exp, _omegas_array, marginal=True) for _exp in _exp_vals])
    # Running sum of the pdf times the grid spacing pi / X_N.
    _cdf_vals = np.asarray([_pdf.cumsum() / X_N * np.pi for _pdf in _pdf_vals])
    _score_norms = np.asarray(
        [_score(_exp, _omegas_array, _eps) for _exp, _eps in zip(_exp_vals, _eps_array)]
    )

    # Per-eps value: sqrt of the pdf-weighted mean of the squared score,
    # divided by pi.
    _exp_score_norms = np.sqrt(np.sum(_score_norms**2 * _pdf_vals, axis=1) / np.sum(_pdf_vals, axis=1) / np.pi)

    np.save('.so3_omegas_array2.npy', _omegas_array)
    np.save('.so3_cdf_vals2.npy', _cdf_vals)
    np.save('.so3_score_norms2.npy', _score_norms)
    np.save('.so3_exp_score_norms2.npy', _exp_score_norms)
def sample(eps):
    """Draw one random rotation angle for noise scale ``eps``.

    ``eps`` is mapped to the nearest row of the precomputed CDF table
    (clipped to the valid range), then a uniform random number is pushed
    through the inverse of that CDF by linear interpolation.
    """
    # NOTE(review): the index is scaled by N_EPS although the eps grid has
    # N_EPS points (N_EPS - 1 intervals); kept as in the original.
    log_min = np.log10(MIN_EPS)
    log_max = np.log10(MAX_EPS)
    position = (np.log10(eps) - log_min) / (log_max - log_min) * N_EPS
    row = np.clip(np.around(position).astype(int), a_min=0, a_max=N_EPS - 1)

    uniform_draw = np.random.rand()
    return np.interp(uniform_draw, _cdf_vals[row], _omegas_array)
def sample_vec(eps):
    """Draw a random 3-vector: a random unit direction scaled by ``sample(eps)``.

    The direction is a normalised standard-normal 3-vector.
    """
    direction = np.random.randn(3)
    direction = direction / np.linalg.norm(direction)
    return direction * sample(eps)
def score_vec(eps, vec):
    """Return the score vector for ``vec`` at noise scale ``eps``.

    The score magnitude is linearly interpolated from the precomputed table
    at the norm of ``vec`` and applied along the direction ``vec / |vec|``.
    A zero ``vec`` divides by zero.
    """
    # NOTE(review): the index is scaled by N_EPS although the eps grid has
    # N_EPS points (N_EPS - 1 intervals); kept as in the original.
    log_min = np.log10(MIN_EPS)
    log_max = np.log10(MAX_EPS)
    position = (np.log10(eps) - log_min) / (log_max - log_min) * N_EPS
    row = np.clip(np.around(position).astype(int), a_min=0, a_max=N_EPS - 1)

    angle = np.linalg.norm(vec)
    magnitude = np.interp(angle, _omegas_array, _score_norms[row])
    return magnitude * vec / angle
def score_norm(eps):
    """Look up the precomputed expected score norm for each entry of ``eps``.

    Args:
        eps: torch tensor of noise scales (any shape, including 0-dim).

    Returns:
        A float32 CPU tensor with the same shape as ``eps``, holding the
        table value for the nearest grid index (clipped to the valid range).
    """
    # detach()/cpu() make the conversion work for tensors that require grad or
    # live on another device; plain Tensor.numpy() raises for those.
    eps = eps.detach().cpu().numpy()
    # NOTE(review): the index is scaled by N_EPS although the eps grid has
    # N_EPS points (N_EPS - 1 intervals); kept as in the original.
    eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS
    eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1)
    # For a 0-dim eps the lookup yields a NumPy scalar, which
    # torch.from_numpy rejects; np.asarray turns it back into an ndarray.
    return torch.from_numpy(np.asarray(_exp_score_norms[eps_idx])).float()