Spaces:
Sleeping
Sleeping
import torch | |
import matplotlib.pyplot as plt | |
def linear_beta_schedule(timesteps): | |
beta_start = 1e-4 | |
beta_end = 0.02 | |
return torch.linspace(beta_start, beta_end, timesteps) | |
def quadratic_beta_schedule(timesteps): | |
beta_start = 1e-4 | |
beta_end = 0.02 | |
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 | |
def sigmoid_beta_schedule(timesteps): | |
beta_start = 1e-4 | |
beta_end = 0.02 | |
s = torch.tensor(6.0) | |
betas = torch.sigmoid(torch.linspace(-s, s, timesteps)) | |
return betas * (beta_end - beta_start) + beta_start | |
def plot_beta_schedules(timesteps): | |
plt.figure(figsize=(10, 6)) | |
plt.plot(linear_beta_schedule(timesteps), label='Linear') | |
plt.plot(quadratic_beta_schedule(timesteps), label='Quadratic') | |
plt.plot(sigmoid_beta_schedule(timesteps), label='Sigmoid') | |
plt.xlabel('Timestep') | |
plt.ylabel('Beta') | |
plt.legend() | |
plt.show() | |
if __name__ == '__main__': | |
timesteps = 1000 | |
plot_beta_schedules(timesteps) | |