File size: 4,781 Bytes
3cdaa7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"""SO(3) diffusion methods."""
import numpy as np
import os
from functools import cached_property
import torch
from scipy.spatial.transform import Rotation
import scipy.linalg
### First define geometric operations on the SO3 manifold
# hat map from vector space R^3 to Lie algebra so(3)
def hat(v):
    """Map a batch of vectors in R^3 to skew-symmetric matrices in so(3).

    v: tensor indexed as v[n, i] with i in {0, 1, 2}; the first axis is the
        batch axis.

    Returns a tensor of shape [v.shape[0], 3, 3]. The buffer is created with
    torch.zeros and no explicit dtype, so the result uses torch's default
    dtype regardless of the dtype of v.
    """
    batch_size = v.shape[0]
    # Fill only the strict upper triangle; the lower triangle follows from
    # antisymmetry below.
    upper = torch.zeros([batch_size, 3, 3])
    upper[:, 0, 1] = -v[:, 2]
    upper[:, 0, 2] = v[:, 1]
    upper[:, 1, 2] = -v[:, 0]
    return upper - upper.transpose(2, 1)
# Logarithmic map from SO(3) to R^3 (i.e. rotation vector)
def Log(R):
    """Convert rotation matrices to rotation vectors.

    R: torch tensor of rotation matrices; it is converted with R.numpy() and
        handed to scipy's Rotation.from_matrix.

    Returns a new torch tensor built from scipy's as_rotvec() output.
    """
    rotations = Rotation.from_matrix(R.numpy())
    rotation_vectors = rotations.as_rotvec()
    return torch.tensor(rotation_vectors)
# logarithmic map from SO(3) to so(3), this is the matrix logarithm
def log(R):
    """Matrix logarithm from SO(3) to so(3).

    Computes the rotation vector of R with Log and lifts it to a
    skew-symmetric matrix with hat.
    """
    rotation_vector = Log(R)
    return hat(rotation_vector)
# Exponential map from vector space of so(3) to SO(3), this is the matrix
# exponential combined with the "hat" map
def Exp(A):
    """Convert rotation vectors to rotation matrices.

    A: torch tensor of rotation vectors; it is converted with A.numpy() and
        handed to scipy's Rotation.from_rotvec.

    Returns a new torch tensor built from scipy's as_matrix() output.
    """
    rotations = Rotation.from_rotvec(A.numpy())
    matrices = rotations.as_matrix()
    return torch.tensor(matrices)
# Angle of rotation SO(3) to R^+
def Omega(R):
    """Angle of rotation for each rotation matrix in R.

    The angle is computed as the Frobenius norm of the matrix logarithm
    log(R), taken over the last two axes, divided by sqrt(2). For a
    skew-symmetric matrix hat(v) that norm equals sqrt(2) * |v|, so the
    result is the length of the rotation vector.

    R: rotation matrices in the form accepted by log.

    Returns a numpy array with one angle per matrix.
    """
    # np.linalg.norm accepts only None, an int, or a tuple of ints for
    # `axis`; a list raises TypeError, so the two matrix axes are passed
    # as a tuple.
    return np.linalg.norm(log(R), axis=(-2, -1)) / np.sqrt(2.)
# Default truncation level (number of terms kept) for the IGSO(3) power series.
L_default = 2000


def f_igso3(omega, t, L=L_default):
    """Truncated sum of IGSO(3) distribution.

    This function approximates the power series in equation 5 of
    Leach et al. 2022.

    This expression diverges from the expression in Leach in that here, sigma =
    sqrt(2) * eps, if eps_leach were the scale parameter of the IGSO(3).

    With this reparameterization, IGSO(3) agrees with the Brownian motion on
    SO(3) with t=sigma^2 when defined for the canonical inner product on SO3,
    <u, v>_SO3 = Trace(u v^T)/2

    Args:
        omega: tensor of rotation angles, i.e. the angle of rotation
            associated with a rotation matrix. It is indexed as
            omega[:, None], so the leading axis enumerates the angles.
        t: variance parameter of IGSO(3), maps onto time in Brownian motion
        L: Truncation level (number of series terms, l = 0, ..., L-1)

    Returns:
        Tensor holding the truncated series summed over l, one value per
        angle. Every term divides by sin(omega / 2), so the value is
        undefined (0/0) at omega = 0.
    """
    ls = torch.arange(L)[None]  # of shape [1, L]
    angles = omega[:, None]  # one row per angle, broadcasts against ls
    # (2l + 1) * exp(-l (l + 1) t / 2): the angle-independent factor per term.
    coefficients = (2*ls + 1) * torch.exp(-ls*(ls+1)*t/2)
    terms = coefficients * torch.sin(angles*(ls+1/2)) / torch.sin(angles/2)
    return terms.sum(dim=-1)
def d_logf_d_omega(omega, t, L=L_default):
    """Derivative of log f_igso3 with respect to the rotation angle.

    omega: angles at which to differentiate; copied into a new torch tensor
        with requires_grad=True.
    t: variance parameter forwarded to f_igso3.
    L: truncation level forwarded to f_igso3.

    Returns the gradient as a numpy array, obtained by autograd.
    """
    angles = torch.tensor(omega, requires_grad=True)
    # Summing lets a single autograd call return the derivative for every
    # angle at once.
    total_log_density = torch.log(f_igso3(angles, t, L)).sum()
    (grad_angles,) = torch.autograd.grad(total_log_density, angles)
    return grad_angles.numpy()
# IGSO3 density with respect to the volume form on SO(3)
def igso3_density(Rt, t, L=L_default):
    """IGSO3 density with respect to the volume form on SO(3).

    Rt: rotation matrices; their rotation angles are obtained with Omega.
    t: variance parameter forwarded to f_igso3.
    L: truncation level forwarded to f_igso3.

    Returns the truncated series from f_igso3 as a numpy array.
    """
    angles = torch.tensor(Omega(Rt))
    density = f_igso3(angles, t, L)
    return density.numpy()
def igso3_density_angle(omega, t, L=L_default):
    """IGSO3 series value weighted by (1 - cos(omega)) / pi.

    omega: angles of rotation; wrapped in a torch tensor for f_igso3 and also
        passed to np.cos.
    t: variance parameter forwarded to f_igso3.
    L: truncation level forwarded to f_igso3.

    Returns a numpy array: f_igso3(omega, t, L) * (1 - cos(omega)) / pi.
    """
    series = f_igso3(torch.tensor(omega), t, L).numpy()
    angle_weight = 1 - np.cos(omega)
    return series * angle_weight / np.pi
# grad_R log IGSO3(R; I_3, t)
def igso3_score(R, t, L=L_default):
    """Score of IGSO(3): grad_R log IGSO3(R; I_3, t).

    R: batch of rotation matrices, contracted as R[N, i, j].
    t: variance parameter forwarded to d_logf_d_omega.
    L: truncation level forwarded to d_logf_d_omega.

    Returns R @ log(R) divided by the rotation angle, scaled per batch element
    by the derivative of log f_igso3 at that angle.
    """
    angles = Omega(R)
    # Batched matrix product R @ log(R), normalized by the rotation angle.
    tangent = np.einsum('Nij,Njk->Nik', R, log(R))
    direction = tangent / angles[:, None, None]
    magnitude = d_logf_d_omega(angles, t, L)
    return direction * magnitude[:, None, None]
def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma):
"""calculate_igso3 pre-computes numerical approximations to the IGSO3 cdfs
and score norms and expected squared score norms.
num_sigma: number of different sigmas for which to compute igso3
num_omega: number of point in the discretization in the angle of
min_sigma, max_sigma: the upper and lower ranges for the angle of
rotation on which to consider the IGSO3 distribution. This cannot
be too low or it will create numerical instability.
# Discretize omegas for calculating CDFs. Skip omega=0.
discrete_omega = np.linspace(0, np.pi, num_omega+1)[1:]
# Exponential noise schedule. This choice is closely tied to the
# scalings used when simulating the reverse time SDE. For each step n,
# discrete_sigma[n] = min_eps^(1-n/num_eps) * max_eps^(n/num_eps)
discrete_sigma = 10 ** np.linspace(np.log10(min_sigma), np.log10(max_sigma), num_sigma + 1)[1:]
# Compute the pdf and cdf values for the marginal distribution of the angle
# of rotation (which is needed for sampling)
pdf_vals = np.asarray(
[igso3_density_angle(discrete_omega, sigma**2) for sigma in discrete_sigma])
cdf_vals = np.asarray(
[pdf.cumsum() / num_omega * np.pi for pdf in pdf_vals])
# Compute the norms of the scores. This are used to scale the rotation axis when
# computing the score as a vector.
score_norm = np.asarray(
[d_logf_d_omega(discrete_omega, sigma**2) for sigma in discrete_sigma])
# Compute the standard deviation of the score norm for each sigma
exp_score_norms = np.sqrt(
score_norm**2 * pdf_vals, axis=1) / np.sum(
pdf_vals, axis=1))
return {
'cdf': cdf_vals,
'score_norm': score_norm,
'exp_score_norms': exp_score_norms,
'discrete_omega': discrete_omega,
'discrete_sigma': discrete_sigma,