Spaces:
Runtime error
Runtime error
File size: 882 Bytes
bfd34e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
def linear(n_timestep = 1000, start = 1e-4, end = 2e-2):
    """Build a ``Schedule`` whose betas are linear in square-root space.

    ``n_timestep`` float64 values are spaced evenly between ``start ** 0.5``
    and ``end ** 0.5`` and then squared, so the first beta is ``start`` and
    the last is ``end`` (up to rounding).

    Args:
        n_timestep: number of betas to generate.
        start: first beta value.
        end: last beta value.

    Returns:
        A ``Schedule`` constructed from the resulting beta tensor.
    """
    sqrt_first = start ** 0.5
    sqrt_last = end ** 0.5
    # Interpolate the square roots, then square back to get the betas.
    sqrt_betas = torch.linspace(sqrt_first, sqrt_last, n_timestep, dtype = torch.float64)
    return Schedule(sqrt_betas ** 2)
class Schedule:
    """Lookup tables derived from a 1-D tensor of per-step ``betas``.

    Attributes (each has the same shape as ``betas``):
        betas: the per-step values passed to the constructor.
        _alphas: per-step ``1 - betas``.
        alphas: cumulative product of ``_alphas`` along dim 0.
        one_minus_alphas: ``1 - alphas``.
        sqrt_alphas: ``sqrt(alphas)``.
        sqrt_one_minus_alphas: ``sqrt(1 - alphas)``.
        sqrt_noise_signal_ratio: ``sqrt_one_minus_alphas / sqrt_alphas``.
        noise_signal_ratio: ``(1 - alphas) / alphas``.
    """

    def __init__(self, betas):
        """Precompute every table from ``betas`` (a torch tensor)."""
        self.betas = betas
        self._alphas = 1 - betas
        # NOTE: ``alphas`` is the *cumulative* product; ``_alphas`` is per-step.
        self.alphas = torch.cumprod(self._alphas, 0)
        self.one_minus_alphas = 1 - self.alphas
        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_one_minus_alphas = torch.sqrt(self.one_minus_alphas)
        self.sqrt_noise_signal_ratio = self.sqrt_one_minus_alphas / self.sqrt_alphas
        self.noise_signal_ratio = self.one_minus_alphas / self.alphas

    def range(self, dt):
        """Return descending step indices from ``len(betas) - 1`` down to 1.

        Indices decrease by ``dt``; index 0 is never produced.
        """
        return range(len(self.betas) - 1, 0, -dt)

    def sigma(self, t, dt):
        """Return the standard deviation of the noise for a step ``t -> t - dt``.

        Computes ``sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))`` where
        ``a_t = alphas[t]`` and ``a_prev = alphas[t - dt]`` are cumulative
        products.  The per-step ``_alphas`` must not be used here: the
        expression is only the posterior standard deviation when written in
        terms of the cumulative values.

        Args:
            t: current step index (int or integer tensor).
            dt: step size; the previous index is ``t - dt``.

        Returns:
            A tensor holding sigma (0-dim for a scalar ``t``).  When
            ``t - dt`` falls before index 0 the previous cumulative alpha is
            taken to be 1, so the result is 0 instead of silently wrapping
            around to the end of the table via a negative index.
        """
        alphas = self.alphas
        t = torch.as_tensor(t, device = alphas.device)
        t_prev = t - dt
        alpha_t = alphas[t]
        # Clamp only to keep the gather in range; out-of-range entries are
        # replaced by 1 (no noise before the first step) right after.
        alpha_prev = torch.where(
            t_prev >= 0,
            alphas[t_prev.clamp(min = 0)],
            torch.ones_like(alpha_t),
        )
        return torch.sqrt((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev))
|