pan-yl's picture
update file
2a00960
raw
history blame
1.04 kB
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from scepter.modules.model.registry import NOISE_SCHEDULERS
from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler
@NOISE_SCHEDULERS.register_class()
class LinearScheduler(BaseNoiseScheduler):
    """Noise scheduler whose betas are evenly spaced between two endpoints.

    The endpoints are read from the config keys ``BETA_MIN`` and
    ``BETA_MAX``; sigmas and alphas are then derived from those betas.
    """

    # NOTE(review): presumably consumed by the registry/config tooling of the
    # base class -- confirm against BaseNoiseScheduler.
    para_dict = {}

    def init_params(self):
        """Load the beta endpoints from ``self.cfg`` after the base setup.

        ``BETA_MIN`` falls back to 0.00085 and ``BETA_MAX`` to 0.012 when the
        key is absent from the config.
        """
        super().init_params()
        cfg = self.cfg
        self.beta_min = cfg.get('BETA_MIN', 0.00085)
        self.beta_max = cfg.get('BETA_MAX', 0.012)

    def betas_to_sigmas(self, betas):
        """Return ``sqrt(1 - cumprod(1 - betas))`` taken along dim 0."""
        cumulative_product = torch.cumprod(1 - betas, dim=0)
        return torch.sqrt(1 - cumulative_product)

    def get_schedule(self):
        """Build the schedule and store it on the instance.

        Sets ``_sigmas``, ``_betas``, ``_alphas`` and ``_timesteps``.
        ``self.num_timesteps`` is not set in this class; presumably it is
        provided by the base class -- confirm.
        """
        num_steps = self.num_timesteps
        linear_betas = torch.linspace(
            self.beta_min, self.beta_max, num_steps, dtype=torch.float32)
        schedule_sigmas = self.betas_to_sigmas(linear_betas)

        self._sigmas = schedule_sigmas
        self._betas = linear_betas
        # alphas satisfy alpha**2 + sigma**2 == 1 by construction.
        self._alphas = torch.sqrt(1 - schedule_sigmas.pow(2))
        # One float index per schedule entry: 0, 1, ..., len(sigmas) - 1.
        self._timesteps = torch.arange(
            schedule_sigmas.shape[0], dtype=torch.float32)