Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch as th | |
import torch.nn as nn | |
from torchdiffeq import odeint | |
from functools import partial | |
from tqdm import tqdm | |
class sde:
    """Fixed-step SDE sampler on a uniform forward-time grid.

    The grid is ``th.linspace(t0, t1, num_steps)``; every step advances by the
    constant ``self.dt = t[1] - t[0]``.

    ``drift`` is called as ``drift(x, t, model, **model_kwargs)`` and
    ``diffusion`` as ``diffusion(x, t)``, where ``t`` is a tensor of length
    ``x.size(0)`` filled with the current time.
    """

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        """Build the time grid and store the drift/diffusion callables.

        Args:
            drift: callable ``(x, t, model, **model_kwargs) -> tensor like x``.
            diffusion: callable ``(x, t) -> tensor`` broadcastable against x.
            t0: start time; must be strictly smaller than ``t1``.
            t1: end time.
            num_steps: number of grid points (the sampler takes
                ``num_steps - 1`` steps); must be at least 2.
            sampler_type: ``"Euler"`` or ``"Heun"``; validated lazily in
                ``sample``.

        Raises:
            AssertionError: if ``t0 >= t1`` (kept as an assert for
                backward compatibility with existing callers).
            ValueError: if ``num_steps < 2``, since no step size can be
                derived from fewer than two grid points.
        """
        assert t0 < t1, "SDE sampler has to be in forward time"
        if num_steps < 2:
            # Previously this surfaced as an opaque IndexError on self.t[1].
            raise ValueError(
                f"num_steps must be at least 2 to define a step size, got {num_steps}"
            )
        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        """Advance ``x`` by one Euler-Maruyama step.

        The incoming ``mean_x`` is ignored; it is accepted so every step
        function shares the same signature.

        Returns:
            ``(x_next, mean_next)``. Here ``mean_next = x + drift * dt`` is the
            noise-free update, and ``x_next`` adds ``sqrt(2 * diffusion) * dW``.
        """
        # Noise is drawn on the default (CPU) generator and then moved to x's
        # device/dtype; kept as-is so seeded runs stay reproducible.
        w_cur = th.randn(x.size()).to(x)
        t_vec = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t_vec, model, **model_kwargs)
        diffusion = self.diffusion(x, t_vec)
        mean_next = x + drift * self.dt
        x_next = mean_next + th.sqrt(2 * diffusion) * dw
        return x_next, mean_next

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        """Advance ``x`` by one stochastic Heun step.

        First ``x`` is perturbed with diffusion noise to give ``xhat``. Then a
        trapezoidal drift update is applied from ``xhat``, using the drift
        evaluated at ``t`` and at ``t + dt``. The second positional argument
        (the previous mean) is ignored.

        Returns:
            ``(x_next, xhat)``: the Heun-corrected state and the noised state
            before the drift update.
        """
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        # NOTE(review): this method always applies the Heun correction, even
        # on the final grid interval. Any special handling of the last step
        # must be done by the caller.
        return xhat + 0.5 * self.dt * (K1 + K2), xhat

    def __forward_fn(self):
        """Return the step function selected by ``self.sampler_type``.

        TODO: generalize by registering every private method whose name ends
        in ``_step``.

        Raises:
            NotImplementedError: if ``sampler_type`` is not a known name.
        """
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }
        try:
            return sampler_dict[self.sampler_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable sampler_type values, which the
            # original bare `except:` also mapped to NotImplementedError.
            raise NotImplementedError(
                f"Sampler type {self.sampler_type!r} not implemented."
            ) from None

    def sample(self, init, model, **model_kwargs):
        """Integrate the SDE forward from ``init`` over the time grid.

        Args:
            init: initial state tensor.
            model: forwarded unchanged to ``drift``.
            **model_kwargs: forwarded unchanged to ``drift``.

        Returns:
            A list of ``num_steps - 1`` states, one after each step. ``init``
            itself is not included.

        Raises:
            NotImplementedError: if ``sampler_type`` is unknown.
        """
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        # Sampling never needs gradients; enter no_grad once, not per step.
        with th.no_grad():
            for ti in self.t[:-1]:
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)
        return samples
class ode:
    """Forward-time ODE sampler built on ``torchdiffeq.odeint``.

    The integration grid is ``th.linspace(t0, t1, num_steps)``, optionally
    warped by a time-shifting factor ``s`` as ``t -> t / (t + s - s * t)``.
    """

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
        time_shifting_factor=None,
    ):
        """Store the drift callable and build the (possibly shifted) time grid.

        Args:
            drift: callable ``(x, t, model, **model_kwargs)`` returning dx/dt.
            t0: start time; must be strictly smaller than ``t1``.
            t1: end time.
            sampler_type: passed to ``odeint`` as ``method``.
            num_steps: number of points in the evaluation grid.
            atol: absolute tolerance forwarded to ``odeint``.
            rtol: relative tolerance forwarded to ``odeint``.
            time_shifting_factor: if truthy, the grid is warped as
                ``t / (t + s - s * t)``. ``None`` or ``0`` leaves it linear.
        """
        assert t0 < t1, "ODE sampler has to be in forward time"
        self.drift = drift
        grid = th.linspace(t0, t1, num_steps)
        if time_shifting_factor:
            shift = time_shifting_factor
            grid = grid / (grid + shift - shift * grid)
        self.t = grid
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):
        """Solve the ODE from ``x`` across ``self.t`` with ``odeint``.

        Args:
            x: initial state, either a tensor or a tuple of tensors.
            model: forwarded unchanged to ``drift``.
            **model_kwargs: forwarded unchanged to ``drift``.

        Returns:
            Whatever ``odeint`` returns for this input. Per the torchdiffeq
            convention this is presumably the solution at every grid point,
            stacked along a leading time axis (confirm against the installed
            version).
        """
        is_tuple = isinstance(x, tuple)
        device = x[0].device if is_tuple else x.device
        # One tolerance entry per state tensor.
        n_states = len(x) if is_tuple else 1

        def _fn(t, state):
            # Broadcast the scalar solver time to one entry per batch element.
            lead = state[0] if isinstance(state, tuple) else state
            t_batch = th.ones(lead.size(0)).to(device) * t
            return self.drift(state, t_batch, model, **model_kwargs)

        return odeint(
            _fn,
            x,
            self.t.to(device),
            method=self.sampler_type,
            atol=[self.atol] * n_states,
            rtol=[self.rtol] * n_states,
        )