|
""" |
|
Abstract SDE classes, Reverse SDE, and VE/VP SDEs. |
|
|
|
Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py |
|
""" |
|
import abc |
|
import warnings |
|
|
|
import numpy as np |
|
from sgmse.util.tensors import batch_broadcast |
|
import torch |
|
|
|
from sgmse.util.registry import Registry |
|
|
|
|
|
SDERegistry = Registry("SDE") |
|
|
|
|
|
class SDE(abc.ABC):

    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):

        """Construct an SDE.

        Args:

            N: number of discretization time steps.

        """

        super().__init__()

        self.N = N  # number of discretization steps used by samplers

    @property

    @abc.abstractmethod

    def T(self):

        """End time of the SDE."""

        pass

    @abc.abstractmethod

    def sde(self, x, t, *args):

        """Return the drift and diffusion terms of the forward SDE at state `x` and time `t`."""

        pass

    @abc.abstractmethod

    def marginal_prob(self, x, t, *args):

        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""

        pass

    @abc.abstractmethod

    def prior_sampling(self, shape, *args):

        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""

        pass

    @abc.abstractmethod

    def prior_logp(self, z):

        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:

            z: latent code

        Returns:

            log probability density

        """

        pass

    @staticmethod

    @abc.abstractmethod

    def add_argparse_args(parent_parser):

        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """

        pass

    def discretize(self, x, t, y, stepsize):

        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.

        Defaults to Euler-Maruyama discretization.

        Args:

            x: a torch tensor

            t: a torch float representing the time step (from 0 to `self.T`)

            y: conditioning signal, forwarded to `self.sde`

            stepsize: time increment dt of the discretization.
                NOTE(review): passed to `torch.sqrt`, so a tensor appears to be
                expected rather than a plain float — confirm against callers.

        Returns:

            f, G

        """

        dt = stepsize

        drift, diffusion = self.sde(x, t, y)

        # Euler-Maruyama: deterministic term scales with dt, stochastic term with sqrt(dt).
        f = drift * dt

        G = diffusion * torch.sqrt(dt)

        return f, G

    def reverse(oself, score_model, probability_flow=False):

        """Create the reverse-time SDE/ODE.

        Args:

            score_model: A function that takes x, t and y and returns the score.

            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

        """

        # `oself` ("outer self") avoids being shadowed by the `self` of the nested RSDE methods.
        N = oself.N

        T = oself.T

        sde_fn = oself.sde

        discretize_fn = oself.discretize

        # The reverse SDE subclasses the forward SDE's concrete class, so helpers such as
        # `marginal_prob` and `prior_sampling` remain available on the returned object.
        class RSDE(oself.__class__):

            def __init__(self):

                # Deliberately does not call super().__init__(): all needed state is
                # captured from the enclosing `reverse` call's closure.
                self.N = N

                self.probability_flow = probability_flow

            @property

            def T(self):

                return T

            def sde(self, x, t, *args):

                """Create the drift and diffusion functions for the reverse SDE/ODE."""

                rsde_parts = self.rsde_parts(x, t, *args)

                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]

                return total_drift, diffusion

            def rsde_parts(self, x, t, *args):

                """Return all intermediate terms of the reverse SDE in a dict, for inspection."""

                sde_drift, sde_diffusion = sde_fn(x, t, *args)

                score = score_model(x, t, *args)

                # Score term is halved for the probability flow ODE.
                # NOTE(review): the [:, None, None, None] broadcast assumes a 4-D
                # (batch, ...)-shaped x — confirm with the model's tensor layout.
                score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)

                # The probability flow ODE is deterministic, hence zero diffusion.
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion

                total_drift = sde_drift + score_drift

                return {

                    'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,

                    'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,

                }

            def discretize(self, x, t, y, stepsize):

                """Create discretized iteration rules for the reverse diffusion sampler."""

                f, G = discretize_fn(x, t, y, stepsize)

                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)

                rev_G = torch.zeros_like(G) if self.probability_flow else G

                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod

    def copy(self):

        """Return a new instance of this SDE with identical hyperparameters."""

        pass
|
|
|
|
|
@SDERegistry.register("ouve")
class OUVESDE(SDE):
    """Ornstein-Uhlenbeck Variance Exploding SDE:

        dx = -theta (y-x) dt + sigma(t) dw

    with

        sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

    Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
    to the methods which require it (e.g., `sde` or `marginal_prob`).
    """

    @staticmethod
    def add_argparse_args(parser):
        """Add the OUVE hyperparameters to `parser` and return it."""
        # Fixed: help text previously claimed "30 by default" while the actual default is 1000.
        parser.add_argument("--sde-n", type=int, default=1000, help="The number of timesteps in the SDE discretization. 1000 by default")
        parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
        parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
        parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
        return parser

    def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
        """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.

        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
        to the methods which require it (e.g., `sde` or `marginal_prob`).

        dx = -theta (y-x) dt + sigma(t) dw

        with

        sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

        Args:
            theta: stiffness parameter.
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            N: number of discretization steps
        """
        super().__init__(N)  # sets self.N; the previous redundant re-assignment was removed
        self.theta = theta
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # Precompute log(sigma_max / sigma_min), reused by `sde` and `_std`.
        self.logsig = np.log(self.sigma_max / self.sigma_min)

    def copy(self):
        """Return a new OUVESDE with identical hyperparameters."""
        return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)

    @property
    def T(self):
        """End time of the SDE."""
        return 1

    def sde(self, x, t, y):
        """Drift and diffusion of the forward SDE at state `x`, time `t`, steady-state mean `y`."""
        drift = self.theta * (y - x)
        # Geometric interpolation of sigma between sigma_min (t=0) and sigma_max (t=1),
        # scaled so that the marginal std matches the closed form in `_std`.
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        diffusion = sigma * np.sqrt(2 * self.logsig)
        return drift, diffusion

    def _mean(self, x0, t, y):
        """Mean of p_t(x | x0, y): exponential interpolation from x0 towards y."""
        theta = self.theta
        # NOTE(review): broadcasting assumes 4-D (batch, ...) tensors — confirm with model layout.
        exp_interp = torch.exp(-theta * t)[:, None, None, None]
        return exp_interp * x0 + (1 - exp_interp) * y

    def alpha(self, t):
        """Decay factor exp(-theta*t) applied to the initial state in the marginal mean."""
        return torch.exp(-self.theta * t)

    def _std(self, t):
        """Closed-form standard deviation of p_t(x | x0, y)."""
        sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
        return torch.sqrt(
            (
                sigma_min**2
                * torch.exp(-2 * theta * t)
                * (torch.exp(2 * (theta + logsig) * t) - 1)
                * logsig
            )
            /
            (theta + logsig)
        )

    def marginal_prob(self, x0, t, y):
        """Mean and std of the marginal distribution p_t(x | x0, y)."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T ~ N(y, std(T)^2 I). `shape` is only validated against `y.shape`."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(torch.ones((y.shape[0],), device=y.device))
        x_T = y + torch.randn_like(y) * std[:, None, None, None]
        return x_T

    def prior_logp(self, z):
        """Log-density of the prior; not needed for sampling, hence unimplemented."""
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
|
|
|
|
|
@SDERegistry.register("ouvp")
class OUVPSDE(SDE):
    """Ornstein-Uhlenbeck Variance Preserving SDE.

    !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!

        dx = -1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw

    with

        beta(t) = beta_min + t(beta_max - beta_min)

    The "steady-state mean" `y` is not stored on the instance; it must be passed to
    the methods that need it (e.g., `sde` or `marginal_prob`).
    """

    @staticmethod
    def add_argparse_args(parser):
        """Register the OUVP hyperparameters on `parser` and return it."""
        parser.add_argument(
            "--sde-n", type=int, default=1000,
            help="The number of timesteps in the SDE discretization. 1000 by default")
        parser.add_argument(
            "--beta-min", type=float, required=True,
            help="The minimum beta to use.")
        parser.add_argument(
            "--beta-max", type=float, required=True,
            help="The maximum beta to use.")
        parser.add_argument(
            "--stiffness", type=float, default=1,
            help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
        return parser

    def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
        """Construct an Ornstein-Uhlenbeck Variance Preserving SDE.

        Args:
            beta_min: smallest beta.
            beta_max: largest beta.
            stiffness: stiffness factor of the drift. 1 by default.
            N: number of discretization steps
        """
        super().__init__(N)
        self.stiffness = stiffness
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.N = N

    def copy(self):
        """Return a fresh OUVPSDE configured like this one."""
        return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)

    @property
    def T(self):
        """End time of the SDE."""
        return 1

    def _beta(self, t):
        """Linear noise schedule beta(t) = beta_min + t * (beta_max - beta_min)."""
        slope = self.beta_max - self.beta_min
        return self.beta_min + t * slope

    def sde(self, x, t, y):
        """Drift and diffusion of the forward SDE at state `x`, time `t`, mean `y`."""
        beta_t = self._beta(t)
        drift = 0.5 * self.stiffness * batch_broadcast(beta_t, y) * (y - x)
        return drift, torch.sqrt(beta_t)

    def _mean(self, x0, t, y):
        """Mean of p_t(x | x0, y): y plus an exponentially decaying pull from x0."""
        bmin, bmax, stiff = self.beta_min, self.beta_max, self.stiffness
        decay = torch.exp(-0.25 * stiff * t * (t * (bmax - bmin) + 2 * bmin))[:, None, None, None]
        return y + decay * (x0 - y)

    def _std(self, t):
        """Closed-form standard deviation of p_t(x | x0, y)."""
        bmin, bmax, stiff = self.beta_min, self.beta_max, self.stiffness
        return (1 - torch.exp(-0.5 * stiff * t * (t * (bmax - bmin) + 2 * bmin))) / stiff

    def marginal_prob(self, x0, t, y):
        """Mean and std of the marginal distribution p_t(x | x0, y)."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T ~ N(y, std(T)^2 I). `shape` is only validated against `y.shape`."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        t_final = torch.ones((y.shape[0],), device=y.device)
        sigma_T = self._std(t_final)
        return y + torch.randn_like(y) * sigma_T[:, None, None, None]

    def prior_logp(self, z):
        """Log-density of the prior; not needed for sampling, hence unimplemented."""
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
|
|