# Uploaded by AlienChen ("Upload 72 files", commit 3527383, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Union
import torch
from torch import Tensor
@dataclass
class SchedulerOutput:
    r"""Scheduler coefficients of a conditional-flow probability path at time(s) ``t``.

    Each attribute is a tensor with the same shape (...) as the times the scheduler
    was evaluated at.

    Attributes:
        alpha_t (Tensor): :math:`\alpha_t`, shape (...).
        sigma_t (Tensor): :math:`\sigma_t`, shape (...).
        d_alpha_t (Tensor): time derivative :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
        d_sigma_t (Tensor): time derivative :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).
    """

    # Field order is part of the positional constructor signature; keep it stable.
    alpha_t: Tensor = field(metadata=dict(help="alpha_t"))
    sigma_t: Tensor = field(metadata=dict(help="sigma_t"))
    d_alpha_t: Tensor = field(metadata=dict(help="Derivative of alpha_t."))
    d_sigma_t: Tensor = field(metadata=dict(help="Derivative of sigma_t."))
class Scheduler(ABC):
    """Abstract base for schedulers.

    A scheduler maps times ``t`` to the path coefficients ``alpha_t`` and
    ``sigma_t`` (with their time derivatives), and can map a signal-to-noise
    ratio ``alpha_t / sigma_t`` back to ``t``.
    """

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Evaluate the schedule.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """

    @abstractmethod
    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""Recover :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
class ConvexScheduler(Scheduler):
    r"""Base class for schedulers of convex paths.

    A convex path is driven by a single coefficient :math:`\kappa_t`. The
    :meth:`snr_inverse` implementation below assumes :math:`\alpha_t = \kappa_t`
    and :math:`\sigma_t = 1 - \kappa_t`. Subclasses therefore only provide the
    schedule itself and :meth:`kappa_inverse`.
    """

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Evaluate the convex schedule.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from :math:`\kappa_t`.

        Args:
            kappa (Tensor): :math:`\kappa`, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        # With alpha_t = kappa_t and sigma_t = 1 - kappa_t:
        # snr = kappa_t / (1 - kappa_t)  =>  kappa_t = snr / (1 + snr).
        kappa_t = snr / (1.0 + snr)
        return self.kappa_inverse(kappa=kappa_t)
class CondOTScheduler(ConvexScheduler):
    """CondOT (linear) scheduler: ``alpha_t = t`` and ``sigma_t = 1 - t``."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        # Both coefficients are linear in t, so their slopes are constant (+1 / -1).
        unit_slope = torch.ones_like(t)
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=unit_slope,
            d_sigma_t=-unit_slope,
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        # kappa_t equals t for this schedule, so the inverse is the identity.
        return kappa
class PolynomialConvexScheduler(ConvexScheduler):
    """Polynomial convex scheduler: ``alpha_t = t**n`` and ``sigma_t = 1 - t**n``."""

    def __init__(self, n: Union[float, int]) -> None:
        """
        Args:
            n (float | int): positive exponent of the polynomial schedule.
        """
        # NOTE(review): these checks use `assert`, so they are skipped under `python -O`.
        assert isinstance(n, (float, int)), f"`n` must be a float or int. Got {type(n)=}."
        assert n > 0, f"`n` must be positive. Got {n=}."
        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        power = t**self.n
        # d/dt t**n = n * t**(n - 1); sigma_t's slope is its negation.
        slope = self.n * (t ** (self.n - 1))
        return SchedulerOutput(
            alpha_t=power,
            sigma_t=1 - power,
            d_alpha_t=slope,
            d_sigma_t=-slope,
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        # kappa_t = t**n  =>  t = kappa_t**(1/n)
        exponent = 1.0 / self.n
        return torch.pow(kappa, exponent)
class VPScheduler(Scheduler):
    r"""Variance Preserving scheduler.

    With :math:`s = 1 - t`, :math:`b` = ``beta_min`` and :math:`B` = ``beta_max``,
    the schedule uses

    .. math::
        T(t) = \tfrac{1}{2} s^2 (B - b) + s\,b,

    and sets :math:`\alpha_t = e^{-T/2}` and :math:`\sigma_t = \sqrt{1 - e^{-T}}`,
    so that :math:`\alpha_t^2 + \sigma_t^2 = 1`.
    """

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        """
        Args:
            beta_min (float): value ``b`` in the schedule above.
            beta_max (float): value ``B`` in the schedule above.
        """
        self.beta_min = beta_min
        self.beta_max = beta_max
        super().__init__()

    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Evaluate the VP schedule.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        b = self.beta_min
        B = self.beta_max
        s = 1 - t
        T = 0.5 * s**2 * (B - b) + s * b
        dT = -s * (B - b) - b  # dT/dt
        alpha_t = torch.exp(-0.5 * T)
        # 1 - exp(-T) computed via expm1: avoids catastrophic cancellation as
        # t -> 1 (T -> 0), where sigma_t is small.
        sigma_t = torch.sqrt(-torch.expm1(-T))
        return SchedulerOutput(
            alpha_t=alpha_t,
            sigma_t=sigma_t,
            d_alpha_t=-0.5 * dT * alpha_t,
            d_sigma_t=0.5 * dT * torch.exp(-T) / sigma_t,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        b = self.beta_min
        B = self.beta_max
        # snr^2 = e^{-T} / (1 - e^{-T})  =>  T = log(1 + 1 / snr^2).
        # log1p keeps precision for large snr, where snr^2 / (snr^2 + 1) rounds to 1
        # (or becomes inf/inf = nan once snr**2 overflows).
        T = torch.log1p(1.0 / snr**2)
        if B == b:
            # Constant rate: T = (1 - t) * b is linear, and the quadratic formula
            # below would divide by zero.
            one_minus_t = T / b
        else:
            # '+' root of 0.5 * (B - b) * u^2 + b * u - T = 0 with u = 1 - t.
            one_minus_t = (-b + torch.sqrt(b**2 + 2 * (B - b) * T)) / (B - b)
        return 1 - one_minus_t
class LinearVPScheduler(Scheduler):
    """Linear Variance Preserving scheduler: ``alpha_t = t`` and ``sigma_t = sqrt(1 - t**2)``."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        # sigma_t and the denominator of its derivative are the same square root.
        root = (1 - t**2) ** 0.5
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=root,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-t / root,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        # snr = t / sqrt(1 - t^2)  =>  t^2 = snr^2 / (1 + snr^2)
        snr_sq = snr**2
        return torch.sqrt(snr_sq / (1 + snr_sq))
class CosineScheduler(Scheduler):
    """Cosine scheduler: ``alpha_t = sin(pi t / 2)`` and ``sigma_t = cos(pi t / 2)``."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        half_pi = torch.pi / 2
        angle = half_pi * t
        sin_angle = torch.sin(angle)
        cos_angle = torch.cos(angle)
        # Chain rule: each derivative picks up the factor pi / 2.
        return SchedulerOutput(
            alpha_t=sin_angle,
            sigma_t=cos_angle,
            d_alpha_t=half_pi * cos_angle,
            d_sigma_t=-half_pi * sin_angle,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        # snr = tan(pi t / 2)  =>  t = 2 * atan(snr) / pi
        angle = torch.atan(snr)
        return 2.0 * angle / torch.pi