"""Loss callables: DDPM, flow matching, Riemannian flow matching and
von Mises-Fisher (single and mixture) negative log-likelihoods.

Every loss is called as ``loss(preconditioning, network, batch, generator)``
and obtains the model output via ``preconditioning(network, batch)``.
"""

import torch
from torch.func import vjp, jvp, vmap, jacrev

from utils.manifolds import Sphere, geodesic


def _apply_conditioning_dropout(
    batch, key, drop_rate, batch_size, device, generator=None
):
    """Zero the conditioning of a random subset of samples in ``batch[key]``.

    Nothing happens when ``batch[key]`` is ``None`` or ``drop_rate`` is not
    positive. Otherwise each sample is selected independently with
    probability ``drop_rate`` and ``batch[key]`` is replaced by a detached
    copy whose selected entries are zero.

    The masking is done on a copy so the tensor originally stored in the
    batch is never modified in place: other holders of that tensor keep the
    original values, and no in-place op is recorded on an autograd leaf.
    """
    conditioning = batch[key]
    if conditioning is not None and drop_rate > 0:
        drop_mask = (
            torch.rand(batch_size, device=device, generator=generator)
            < drop_rate
        )
        dropped = conditioning.detach().clone()
        dropped[drop_mask] = 0
        batch[key] = dropped


def _per_sample_gamma(gamma):
    """Strip up to three trailing singleton dims from ``gamma``, keeping dim 0.

    Matches ``gamma.squeeze(-1).squeeze(-1).squeeze(-1)`` whenever that
    result still has at least one dimension. The only difference is that the
    leading dimension is never squeezed, so a batch of one sample yields a
    tensor of shape ``(1,)`` instead of a 0-dim scalar.
    """
    for _ in range(3):
        if gamma.dim() > 1 and gamma.shape[-1] == 1:
            gamma = gamma.squeeze(-1)
    return gamma


class DDPMLoss:
    """Noise-prediction loss on ``y = sqrt(gamma) * x_0 + sqrt(1 - gamma) * n``."""

    def __init__(
        self,
        scheduler,
        cond_drop_rate=0.0,
        conditioning_key="label",
    ):
        """Store the configuration.

        Args:
            scheduler: Callable mapping a tensor of times drawn uniformly
                from [0, 1) to gamma values.
            cond_drop_rate: Per-sample probability of zeroing the
                conditioning; values <= 0 disable the dropout.
            conditioning_key: Key of the conditioning entry in the batch.
        """
        self.scheduler = scheduler
        self.cond_drop_rate = cond_drop_rate
        self.conditioning_key = conditioning_key

    def __call__(self, preconditioning, network, batch, generator=None):
        """Noise ``batch["x_0"]`` and return the squared error on the noise.

        Writes ``batch["y"]`` (noised sample) and ``batch["gamma"]``, and may
        replace ``batch[self.conditioning_key]`` (conditioning dropout).

        Args:
            preconditioning: Callable invoked as
                ``preconditioning(network, batch)``.
            network: Passed through to ``preconditioning``.
            batch: Dict holding ``"x_0"`` and ``self.conditioning_key``.
            generator: Optional ``torch.Generator`` used for all random draws.

        Returns:
            Element-wise ``(prediction - noise) ** 2`` (not reduced).
        """
        x_0 = batch["x_0"]
        batch_size = x_0.shape[0]
        device = x_0.device

        t = torch.rand(
            batch_size, device=device, dtype=x_0.dtype, generator=generator
        )
        # NOTE(review): a single unsqueeze implies x_0 is (batch, dim) and the
        # scheduler returns (batch,) - confirm against the scheduler.
        gamma = self.scheduler(t).unsqueeze(-1)
        n = torch.randn(
            x_0.shape, dtype=x_0.dtype, device=device, generator=generator
        )
        batch["y"] = torch.sqrt(gamma) * x_0 + torch.sqrt(1 - gamma) * n

        _apply_conditioning_dropout(
            batch,
            self.conditioning_key,
            self.cond_drop_rate,
            batch_size,
            device,
            generator,
        )
        batch["gamma"] = _per_sample_gamma(gamma)

        D_n = preconditioning(network, batch)
        return (D_n - n) ** 2


class FlowMatchingLoss:
    """Velocity-matching loss on ``y = gamma * x_0 + (1 - gamma) * n``."""

    def __init__(
        self,
        scheduler,
        cond_drop_rate=0.0,
        conditioning_key="label",
    ):
        """Store the configuration.

        Args:
            scheduler: Callable mapping a tensor of times drawn uniformly
                from [0, 1) to gamma values.
            cond_drop_rate: Per-sample probability of zeroing the
                conditioning; values <= 0 disable the dropout.
            conditioning_key: Key of the conditioning entry in the batch.
        """
        self.scheduler = scheduler
        self.cond_drop_rate = cond_drop_rate
        self.conditioning_key = conditioning_key

    def __call__(self, preconditioning, network, batch, generator=None):
        """Interpolate data and noise; return the squared velocity error.

        Writes ``batch["y"]`` (interpolated sample) and ``batch["gamma"]``,
        and may replace ``batch[self.conditioning_key]``.

        Args:
            preconditioning: Callable invoked as
                ``preconditioning(network, batch)``.
            network: Passed through to ``preconditioning``.
            batch: Dict holding ``"x_0"`` and ``self.conditioning_key``.
            generator: Optional ``torch.Generator`` used for all random draws.

        Returns:
            Element-wise ``(prediction - (x_0 - n)) ** 2`` (not reduced).
        """
        x_0 = batch["x_0"]
        batch_size = x_0.shape[0]
        device = x_0.device

        t = torch.rand(
            batch_size, device=device, dtype=x_0.dtype, generator=generator
        )
        gamma = self.scheduler(t).unsqueeze(-1)
        n = torch.randn(
            x_0.shape, dtype=x_0.dtype, device=device, generator=generator
        )
        batch["y"] = gamma * x_0 + (1 - gamma) * n

        _apply_conditioning_dropout(
            batch,
            self.conditioning_key,
            self.cond_drop_rate,
            batch_size,
            device,
            generator,
        )
        batch["gamma"] = _per_sample_gamma(gamma)

        D_n = preconditioning(network, batch)
        # Regression target: x_0 - n, the derivative of y w.r.t. gamma.
        return (D_n - (x_0 - n)) ** 2


class RiemannianFlowMatchingLoss:
    """Flow-matching loss along geodesics of a ``Sphere`` manifold."""

    def __init__(
        self,
        scheduler,
        cond_drop_rate=0.0,
        conditioning_key="label",
    ):
        """Store the configuration and build the ``Sphere`` manifold.

        Args:
            scheduler: Callable mapping a tensor of times drawn uniformly
                from [0, 1) to gamma values.
            cond_drop_rate: Per-sample probability of zeroing the
                conditioning; values <= 0 disable the dropout.
            conditioning_key: Key of the conditioning entry in the batch.
        """
        self.scheduler = scheduler
        self.cond_drop_rate = cond_drop_rate
        self.conditioning_key = conditioning_key
        self.manifold = Sphere()
        self.manifold_dim = 3

    def __call__(self, preconditioning, network, batch, generator=None):
        """Sample points on geodesics and match their tangent velocity.

        Writes ``batch["y"]`` (point on the path) and ``batch["gamma"]``, and
        may replace ``batch[self.conditioning_key]``.

        Args:
            preconditioning: Callable invoked as
                ``preconditioning(network, batch)``.
            network: Passed through to ``preconditioning``.
            batch: Dict holding ``"x_0"`` (used as the path end point) and
                ``self.conditioning_key``.
            generator: Optional ``torch.Generator`` for the time and
                dropout draws.

        Returns:
            Scalar: mean of ``manifold.inner(y, diff, diff)`` divided by
            ``manifold_dim``, where ``diff = prediction - u_t``.
        """
        x_1 = batch["x_0"]
        batch_size = x_1.shape[0]
        device = x_1.device

        t = torch.rand(
            batch_size, device=device, dtype=x_1.dtype, generator=generator
        )
        gamma = self.scheduler(t).unsqueeze(-1)
        # NOTE(review): random_base is not handed `generator`, so this draw is
        # presumably not controlled by it - confirm in utils.manifolds.
        x_0 = self.manifold.random_base(
            x_1.shape[0], self.manifold_dim
        ).to(x_1)

        def cond_u(x0, x1, t):
            # A jvp with a unit tangent returns both the path value and its
            # derivative with respect to t in one pass.
            path = geodesic(self.manifold, x0, x1)
            x_t, u_t = jvp(path, (t,), (torch.ones_like(t),))
            return x_t, u_t

        y, u_t = vmap(cond_u)(x_0, x_1, gamma)
        y = y.reshape(batch_size, self.manifold_dim)
        u_t = u_t.reshape(batch_size, self.manifold_dim)
        batch["y"] = y

        _apply_conditioning_dropout(
            batch,
            self.conditioning_key,
            self.cond_drop_rate,
            batch_size,
            device,
            generator,
        )
        batch["gamma"] = _per_sample_gamma(gamma)

        D_n = preconditioning(network, batch)
        diff = D_n - u_t
        return self.manifold.inner(y, diff, diff).mean() / self.manifold_dim


class VonFisherLoss:
    """Negative log-likelihood of a von Mises-Fisher density.

    The normaliser ``kappa / (4 * pi * sinh(kappa))`` is hard-coded, so
    ``dim`` is stored but not used by the computation.
    """

    def __init__(self, dim=3):
        """Store ``dim``."""
        self.dim = dim

    def __call__(self, preconditioning, network, batch, generator=None):
        """Return the per-sample negative log-likelihood of ``batch["x_0"]``.

        Args:
            preconditioning: Callable invoked as
                ``preconditioning(network, batch)``; must return
                ``(mu, kappa)``.
            network: Passed through to ``preconditioning``.
            batch: Dict holding ``"x_0"``.
            generator: Unused; kept for a uniform loss signature.

        Returns:
            ``-(log(kappa) - log(4*pi) - log(sinh(kappa)) + kappa * <mu, x>)``
            with the inner product taken over the last dimension (kept).
        """
        x = batch["x_0"]
        mu, kappa = preconditioning(network, batch)
        loss = (
            torch.log((kappa + 1e-8))
            - torch.log(torch.tensor(4 * torch.pi, dtype=kappa.dtype))
            - log_sinh(kappa)
            + kappa * (mu * x).sum(dim=-1, keepdim=True)
        )
        return -loss


class VonFisherMixtureLoss:
    """Negative log-likelihood of a mixture of von Mises-Fisher densities."""

    def __init__(self, dim=3):
        """Store ``dim`` (not used by the computation)."""
        self.dim = dim

    def __call__(self, preconditioning, network, batch, generator=None):
        """Return the negative log of the weighted sum of component densities.

        Args:
            preconditioning: Callable invoked as
                ``preconditioning(network, batch)``; must return
                ``(mu_mixture, kappa_mixture, weights)``, each indexed by
                component along dimension 1.
            network: Passed through to ``preconditioning``.
            batch: Dict holding ``"x_0"``.
            generator: Unused; kept for a uniform loss signature.

        Returns:
            ``-log(sum_i weights_i * density_i(x))``.
        """
        x = batch["x_0"]
        mu_mixture, kappa_mixture, weights = preconditioning(network, batch)
        loss = 0
        for i in range(mu_mixture.shape[1]):
            mu = mu_mixture[:, i]
            kappa = kappa_mixture[:, i].unsqueeze(1)
            # Density in the form kappa * exp(kappa * (<mu, x> - 1))
            # / (2 * pi * (1 - exp(-2 * kappa))), which keeps the exponent
            # non-positive when mu and x are unit vectors.
            loss += weights[:, i].unsqueeze(1) * (
                kappa
                * torch.exp(kappa * ((mu * x).sum(dim=-1, keepdim=True) - 1))
                / (1e-8 + 2 * torch.pi * (1 - torch.exp(-2 * kappa)))
            )
        return -torch.log(loss)


def log_sinh(x):
    """Return ``log(sinh(x))`` as ``x + log((1 - exp(-2x)) / 2)``.

    A small epsilon (1e-8) is added inside the log to keep it finite at
    ``x == 0``.
    """
    return x + torch.log(1e-8 + (1 - torch.exp(-2 * x)) / 2)