|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
from typing import Union |
|
|
|
|
|
import torch |
|
|
|
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
@dataclass
class SchedulerOutput:
    r"""Container for one sample of a conditional-flow generated probability path.

    Attributes:
        alpha_t (Tensor): :math:`\alpha_t`, shape (...).
        sigma_t (Tensor): :math:`\sigma_t`, shape (...).
        d_alpha_t (Tensor): :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
        d_sigma_t (Tensor): :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).
    """

    # Field order defines the positional-constructor contract; keep it fixed.
    alpha_t: Tensor = field(metadata={"help": "alpha_t"})
    sigma_t: Tensor = field(metadata={"help": "sigma_t"})
    d_alpha_t: Tensor = field(metadata={"help": "Derivative of alpha_t."})
    d_sigma_t: Tensor = field(metadata={"help": "Derivative of sigma_t."})
|
|
|
|
|
|
|
|
class Scheduler(ABC):
    """Abstract base class for probability-path schedulers."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Evaluate the scheduler coefficients at the given times.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""Compute :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...
|
|
|
|
|
|
|
|
class ConvexScheduler(Scheduler):
    """Base Scheduler class for convex interpolation paths."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        # NOTE: raw-string prefix is required here — the original plain docstring
        # turned `\alpha` into a BEL character and `\frac` into a form-feed, and
        # `\sigma` is an invalid escape (SyntaxWarning since Python 3.12).
        r"""Scheduler for convex paths.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        # Raw-string prefix added for the same reason as __call__ above
        # (`\kappa` is an invalid escape sequence in a plain string).
        r"""
        Computes :math:`t` from :math:`\kappa_t`.

        Args:
            kappa (Tensor): :math:`\kappa`, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        # For a convex path, snr = kappa / (1 - kappa), hence kappa = snr / (1 + snr).
        kappa_t = snr / (1.0 + snr)

        return self.kappa_inverse(kappa=kappa_t)
|
|
|
|
|
|
|
|
class CondOTScheduler(ConvexScheduler):
    """CondOT (conditional optimal-transport, linear) Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        """Evaluate the linear path: alpha_t = t, sigma_t = 1 - t."""
        grad_alpha = torch.ones_like(t)
        grad_sigma = -torch.ones_like(t)
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=grad_alpha,
            d_sigma_t=grad_sigma,
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        """Identity inverse: for this scheduler kappa_t equals t."""
        return kappa
|
|
|
|
|
|
|
|
class PolynomialConvexScheduler(ConvexScheduler):
    """Polynomial Scheduler: kappa_t = t**n."""

    def __init__(self, n: Union[float, int]) -> None:
        """
        Args:
            n (Union[float, int]): polynomial exponent; must be positive.

        Raises:
            TypeError: if `n` is not a float or int.
            ValueError: if `n` is not positive.
        """
        # Validate with explicit exceptions instead of `assert`: assertions are
        # stripped when Python runs with -O, which would silently disable these
        # checks.
        if not isinstance(n, (float, int)):
            raise TypeError(f"`n` must be a float or int. Got {type(n)=}.")
        if n <= 0:
            raise ValueError(f"`n` must be positive. Got {n=}.")

        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        """Evaluate kappa_t = t**n and its derivative at times `t`."""
        return SchedulerOutput(
            alpha_t=t**self.n,
            sigma_t=1 - t**self.n,
            d_alpha_t=self.n * (t ** (self.n - 1)),
            d_sigma_t=-self.n * (t ** (self.n - 1)),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        """Invert kappa_t = t**n, i.e. t = kappa**(1/n)."""
        return torch.pow(kappa, 1.0 / self.n)
|
|
|
|
|
|
|
|
class VPScheduler(Scheduler):
    """Variance Preserving Scheduler."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        """
        Args:
            beta_min (float): noise-rate coefficient at t = 1.
            beta_max (float): noise-rate coefficient at t = 0.
        """
        self.beta_min = beta_min
        self.beta_max = beta_max
        super().__init__()

    def __call__(self, t: Tensor) -> SchedulerOutput:
        """Evaluate the VP coefficients and their time derivatives at `t`."""
        beta_lo = self.beta_min
        beta_hi = self.beta_max
        # Integrated noise schedule T(t) and its derivative dT/dt.
        total = 0.5 * (1 - t) ** 2 * (beta_hi - beta_lo) + (1 - t) * beta_lo
        d_total = -(1 - t) * (beta_hi - beta_lo) - beta_lo

        decay = torch.exp(-0.5 * total)
        return SchedulerOutput(
            alpha_t=decay,
            sigma_t=torch.sqrt(1 - torch.exp(-total)),
            d_alpha_t=-0.5 * d_total * decay,
            d_sigma_t=0.5 * d_total * torch.exp(-total) / torch.sqrt(1 - torch.exp(-total)),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""Compute :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`."""
        # snr**2 = exp(-T) / (1 - exp(-T))  =>  T = -log(snr**2 / (snr**2 + 1)).
        total = -torch.log(snr**2 / (snr**2 + 1))
        beta_lo = self.beta_min
        beta_hi = self.beta_max
        # Invert the quadratic T(t) via the quadratic formula for (1 - t).
        return 1 - ((-beta_lo + torch.sqrt(beta_lo**2 + 2 * (beta_hi - beta_lo) * total)) / (beta_hi - beta_lo))
|
|
|
|
|
|
|
|
class LinearVPScheduler(Scheduler):
    """Linear Variance Preserving Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        """Evaluate alpha_t = t, sigma_t = sqrt(1 - t**2) and their derivatives."""
        variance = 1 - t**2
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=variance**0.5,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-t / variance**0.5,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        """Invert snr = t / sqrt(1 - t**2), giving t = sqrt(snr**2 / (1 + snr**2))."""
        squared = snr**2
        return torch.sqrt(squared / (1 + squared))
|
|
|
|
|
|
|
|
class CosineScheduler(Scheduler):
    """Cosine Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        """Evaluate alpha_t = sin(pi*t/2), sigma_t = cos(pi*t/2) and derivatives."""
        half_pi = torch.pi / 2
        angle = half_pi * t
        return SchedulerOutput(
            alpha_t=torch.sin(angle),
            sigma_t=torch.cos(angle),
            d_alpha_t=half_pi * torch.cos(angle),
            d_sigma_t=-half_pi * torch.sin(angle),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        """Invert snr = tan(pi*t/2), giving t = (2/pi) * atan(snr)."""
        return 2.0 * torch.atan(snr) / torch.pi
|
|
|