|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import Tensor |
|
|
|
|
|
from flow_matching.utils.manifolds import Manifold |
|
|
|
|
|
|
|
|
class Sphere(Manifold):
    """Represents the unit hypersphere :math:`S^{D-1}` embedded in :math:`R^D`.

    Points are unit-norm vectors along the last tensor dimension (``dim=-1``);
    tangent vectors at a point ``x`` are vectors orthogonal to ``x`` along that
    same dimension. Leading dimensions are handled by standard PyTorch
    broadcasting.
    """

    # Per-dtype threshold on the tangent-vector norm (``expmap``) or on the
    # geodesic distance (``logmap``) below which the closed-form expressions
    # are replaced by numerically safer fallbacks.
    EPS = {torch.float32: 1e-4, torch.float64: 1e-7}

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        """Exponential map: follow the geodesic from ``x`` along tangent vector ``u``.

        Args:
            x: Point(s) on the sphere.
            u: Tangent vector(s) at ``x``.

        Returns:
            ``cos(|u|) * x + sin(|u|) * u / |u|`` where ``|u| > EPS``; elsewhere
            the retraction ``projx(x + u)``, which avoids dividing by a
            vanishing norm.
        """
        norm_u = u.norm(dim=-1, keepdim=True)
        eps = self.EPS[norm_u.dtype]
        cond = norm_u > eps

        # Clamping the denominator is a no-op wherever ``cond`` is True (the only
        # entries torch.where keeps from ``exp``). It keeps the discarded entries
        # finite: with a raw ``norm_u`` a zero ``u`` produces 0/0 = NaN there, and
        # that NaN leaks into the backward pass even though the forward value is
        # masked out.
        exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u.clamp_min(eps)
        retr = self.projx(x + u)

        return torch.where(cond, exp, retr)

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        """Logarithmic map: the tangent vector at ``x`` pointing towards ``y``.

        The Euclidean difference ``y - x`` is projected onto the tangent space at
        ``x`` and rescaled so that its length equals the geodesic distance
        ``dist(x, y)``. When that distance is at most ``EPS``, the projected
        difference is returned without rescaling.

        Args:
            x: Base point(s) on the sphere.
            y: Target point(s) on the sphere.

        Returns:
            Tangent vector(s) at ``x``.
        """
        eps = self.EPS[x.dtype]
        u = self.proju(x, y - x)
        dist = self.dist(x, y, keepdim=True)

        # The clamp guards against dividing by a vanishing projected norm. Note
        # that for antipodal points (y == -x) the projection of y - x is zero, so
        # the result is (near) zero there: the log map is not unique at that point.
        scaled = u * dist / u.norm(dim=-1, keepdim=True).clamp_min(eps)

        return torch.where(dist.gt(eps), scaled, u)

    def projx(self, x: Tensor) -> Tensor:
        """Project ``x`` onto the sphere by normalising it along the last dimension.

        A zero vector has no well-defined projection and yields NaN.
        """
        return x / x.norm(dim=-1, keepdim=True)

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Project ``u`` onto the tangent space at ``x``.

        Removes the component of ``u`` along ``x``; this is the orthogonal
        projection only when ``x`` has unit norm.
        """
        return u - (x * u).sum(dim=-1, keepdim=True) * x

    def dist(self, x: Tensor, y: Tensor, *, keepdim: bool = False) -> Tensor:
        """Geodesic (great-circle) distance ``arccos(<x, y>)`` between sphere points.

        Args:
            x: Point(s) on the sphere.
            y: Point(s) on the sphere.
            keepdim: Whether to keep the reduced last dimension (as size 1).

        Returns:
            Distance(s) in ``[0, pi]``.
        """
        inner = (x * y).sum(-1, keepdim=keepdim)
        # Rounding can push the inner product of unit vectors slightly outside
        # [-1, 1] (e.g. when x == y), where acos returns NaN; clamp it back into
        # the domain. Values already inside [-1, 1] are unchanged. Note acos still
        # has an unbounded derivative at exactly +/-1.
        return torch.acos(inner.clamp(-1.0, 1.0))
|
|
|