|
"""SO(3) diffusion methods.""" |
|
import numpy as np |
|
import os |
|
from functools import cached_property |
|
import torch |
|
from scipy.spatial.transform import Rotation |
|
import scipy.linalg |
|
|
|
|
|
|
|
|
|
|
|
def hat(v):
    """Map vectors in R^3 to their skew-symmetric "hat" matrices.

    ``hat(v) @ w`` equals the cross product ``v x w``.

    Args:
        v: tensor of shape ``(..., 3)`` (``(N, 3)`` is the batched case used
            by ``log``).

    Returns:
        Tensor of shape ``(..., 3, 3)`` with the same dtype and device as
        ``v`` (the matrix is no longer silently allocated as CPU float32).
    """
    hat_v = v.new_zeros(tuple(v.shape[:-1]) + (3, 3))
    # Fill the strict upper triangle; the lower one follows by antisymmetry.
    hat_v[..., 0, 1] = -v[..., 2]
    hat_v[..., 0, 2] = v[..., 1]
    hat_v[..., 1, 2] = -v[..., 0]
    return hat_v - hat_v.transpose(-2, -1)
|
|
|
|
|
def Log(R):
    """Return the rotation vectors (axis * angle) of rotation matrices ``R``.

    Args:
        R: CPU torch tensor of rotation matrices, ``(N, 3, 3)`` or ``(3, 3)``.

    Returns:
        torch tensor built from scipy's float64 rotation vectors,
        ``(N, 3)`` or ``(3,)``.
    """
    as_numpy = R.numpy()
    rotvecs = Rotation.from_matrix(as_numpy).as_rotvec()
    return torch.tensor(rotvecs)
|
|
|
|
|
def log(R):
    """Matrix logarithm of rotation matrices, as skew-symmetric matrices.

    Computed as ``hat`` of the rotation vectors returned by ``Log``.
    """
    rotvec = Log(R)
    return hat(rotvec)
|
|
|
|
|
|
|
def Exp(A):
    """Return the rotation matrices for the rotation vectors ``A``.

    ``A`` holds axis-angle vectors (not skew matrices), i.e. this inverts
    ``Log``.

    Args:
        A: CPU torch tensor of rotation vectors, ``(N, 3)`` or ``(3,)``.

    Returns:
        torch tensor of rotation matrices, ``(N, 3, 3)`` or ``(3, 3)``.
    """
    matrices = Rotation.from_rotvec(A.numpy()).as_matrix()
    return torch.tensor(matrices)
|
|
|
|
|
def Omega(R):
    """Return the rotation angle of each rotation matrix in ``R``.

    The angle equals ``||log(R)||_F / sqrt(2)``. That quantity is exactly the
    Euclidean norm of the rotation vector of ``R``, so it is computed from
    the rotation vector directly. The result stays in float64 and no 3x3 skew
    matrices are built.

    Args:
        R: rotation matrices, ``(N, 3, 3)`` or ``(3, 3)``; a CPU torch tensor
            or anything ``np.asarray`` accepts.

    Returns:
        numpy array of angles in ``[0, pi]``: shape ``(N,)`` for batched
        input, a scalar for a single matrix.
    """
    # NOTE: np.linalg.norm only accepts an int or a tuple for ``axis``; the
    # rotation-vector norm needs just the last axis.
    rotvec = Rotation.from_matrix(np.asarray(R)).as_rotvec()
    return np.linalg.norm(rotvec, axis=-1)
|
|
|
# Default truncation level (number of terms) of the IGSO(3) power series.
L_default = 2000


def f_igso3(omega, t, L=L_default):
    """Truncated sum of IGSO(3) distribution.

    This function approximates the power series in equation 5 of
    "DENOISING DIFFUSION PROBABILISTIC MODELS ON SO(3) FOR ROTATIONAL
    ALIGNMENT"
    Leach et al. 2022

    This expression diverges from the expression in Leach in that here, sigma =
    sqrt(2) * eps, if eps_leach were the scale parameter of the IGSO(3).

    With this reparameterization, IGSO(3) agrees with the Brownian motion on
    SO(3) with t=sigma^2 when defined for the canonical inner product on SO3,
    <u, v>_SO3 = Trace(u v^T)/2

    The whole series is evaluated in the floating dtype (and on the device)
    of ``omega``. Integer input is promoted to torch's default float dtype.
    As a result, float64 angles give full float64 precision.

    Args:
        omega: angle(s) of rotation associated with rotation matrices; a
            tensor (or array-like) of any shape. omega = 0 must be avoided,
            since the sin(omega/2) denominator vanishes there.
        t: variance parameter of IGSO(3), maps onto time in Brownian motion
        L: Truncation level (number of series terms).

    Returns:
        Tensor with the same shape as ``omega``.
    """
    omega = torch.as_tensor(omega)
    if not omega.is_floating_point():
        omega = omega.to(torch.get_default_dtype())
    # Series index l = 0..L-1 in omega's dtype. An int64 index times a Python
    # float t would otherwise promote the exponential to float32.
    ls = torch.arange(L, dtype=omega.dtype, device=omega.device)
    # Trailing axis for the series index; broadcasts over any omega shape.
    omega = omega[..., None]
    terms = ((2 * ls + 1) * torch.exp(-ls * (ls + 1) * t / 2)
             * torch.sin(omega * (ls + 0.5)) / torch.sin(omega / 2))
    return terms.sum(dim=-1)
|
|
|
def d_logf_d_omega(omega, t, L=L_default):
    """Derivative of ``log f_igso3(omega, t)`` with respect to the angle.

    Args:
        omega: array of rotation angles (converted with ``torch.tensor``).
        t: IGSO(3) variance parameter.
        L: truncation level forwarded to ``f_igso3``.

    Returns:
        numpy array with the elementwise derivative, same shape as omega.
    """
    omega_var = torch.tensor(omega, requires_grad=True)
    # f_igso3 only sums over the series index, so each log f(omega_i) depends
    # on omega_i alone; the gradient of the total is the elementwise derivative.
    objective = torch.log(f_igso3(omega_var, t, L)).sum()
    (grad,) = torch.autograd.grad(objective, omega_var)
    return grad.numpy()
|
|
|
|
|
def igso3_density(Rt, t, L=L_default):
    """Evaluate the IGSO(3) series ``f_igso3`` at the angles of rotations ``Rt``.

    No (1 - cos omega) angle weighting is applied here.

    Returns:
        numpy array of density values, one per rotation.
    """
    angles = torch.tensor(Omega(Rt))
    density = f_igso3(angles, t, L)
    return density.numpy()
|
|
|
def igso3_density_angle(omega, t, L=L_default):
    """Marginal density of the rotation angle under IGSO(3).

    The IGSO(3) series is multiplied by (1 - cos omega) / pi, the density of
    the rotation angle under the uniform (Haar) distribution on SO(3).

    Args:
        omega: numpy array of angles in (0, pi].
        t: IGSO(3) variance parameter.
        L: truncation level forwarded to ``f_igso3``.

    Returns:
        numpy array of density values, same shape as ``omega``.
    """
    series = f_igso3(torch.tensor(omega), t, L).numpy()
    return series * (1 - np.cos(omega)) / np.pi
|
|
|
|
|
def igso3_score(R, t, L=L_default):
    """Score of the IGSO(3) density at rotations ``R``.

    The score is the unit tangent direction ``R @ log(R) / omega`` scaled by
    ``d log f / d omega``. It is undefined at omega = 0 (division by the
    angle).

    Args:
        R: batch of rotation matrices, ``(N, 3, 3)`` torch tensor.
        t: IGSO(3) variance parameter.
        L: truncation level forwarded to ``d_logf_d_omega``.

    Returns:
        ``(N, 3, 3)`` numpy array of score matrices.
    """
    angles = Omega(R)
    direction = np.einsum('Nij,Njk->Nik', R, log(R)) / angles[:, None, None]
    magnitude = d_logf_d_omega(angles, t, L)
    return direction * magnitude[:, None, None]
|
|
|
def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma):
    """Pre-compute numerical IGSO(3) tables used for sampling and scoring.

    For each sigma on a log-spaced grid, the following are evaluated on a
    grid of rotation angles:
      - the marginal pdf of the angle;
      - its cumulative distribution;
      - the score norm d log f / d omega;
      - the root-mean-square score norm.

    Args:
        num_sigma: number of sigma values at which to compute IGSO(3)
            quantities.
        num_omega: number of points in the discretization of the angle of
            rotation. The grid covers (0, pi]; omega = 0 is skipped.
        min_sigma, max_sigma: lower and upper bounds of the log-spaced sigma
            grid (the IGSO(3) scale; the variance is t = sigma**2). The grid
            ends at max_sigma and excludes min_sigma itself. min_sigma cannot
            be too low or it will create numerical instability.

    Returns:
        dict with keys
            'cdf': (num_sigma, num_omega) Riemann-sum CDF of the angle pdf,
            'score_norm': (num_sigma, num_omega) d log f / d omega values,
            'exp_score_norms': (num_sigma,) square root of the pdf-weighted
                mean of score_norm**2,
            'discrete_omega': (num_omega,) angle grid,
            'discrete_sigma': (num_sigma,) sigma grid.
    """
    # Angle grid with step pi / num_omega; dropping the first point skips omega = 0.
    discrete_omega = np.linspace(0, np.pi, num_omega + 1)[1:]

    # Log-spaced sigma schedule; the leading point (min_sigma) is dropped.
    log10_sigma = np.linspace(np.log10(min_sigma), np.log10(max_sigma), num_sigma + 1)
    discrete_sigma = 10 ** log10_sigma[1:]

    pdf_rows = []
    score_rows = []
    for sigma in discrete_sigma:
        variance = sigma**2
        pdf_rows.append(igso3_density_angle(discrete_omega, variance))
        score_rows.append(d_logf_d_omega(discrete_omega, variance))
    pdf_vals = np.asarray(pdf_rows)
    score_norm = np.asarray(score_rows)

    # Cumulative Riemann sum with step pi / num_omega, one row per sigma.
    cdf_vals = np.asarray([row.cumsum() / num_omega * np.pi for row in pdf_vals])

    # Root-mean-square score norm under the (unnormalised) angle pdf.
    weighted_sq = np.sum(score_norm**2 * pdf_vals, axis=1)
    total_mass = np.sum(pdf_vals, axis=1)
    exp_score_norms = np.sqrt(weighted_sq / total_mass)

    return {
        'cdf': cdf_vals,
        'score_norm': score_norm,
        'exp_score_norms': exp_score_norms,
        'discrete_omega': discrete_omega,
        'discrete_sigma': discrete_sigma,
    }
|
|