pan-yl's picture
update file
2a00960
raw
history blame
2.84 kB
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from scepter.modules.model.registry import DIFFUSION_SAMPLERS
from scepter.modules.model.diffusion.samplers import BaseDiffusionSampler
from scepter.modules.model.diffusion.util import _i
def _i(tensor, t, x):
"""
Index tensor using t and format the output according to x.
"""
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
if isinstance(t, torch.Tensor):
t = t.to(tensor.device)
return tensor[t].view(shape).to(x.device)
@DIFFUSION_SAMPLERS.register_class('ddim')
class DDIMSampler(BaseDiffusionSampler):
def init_params(self):
super().init_params()
self.eta = self.cfg.get('ETA', 0.)
self.discretization_type = self.cfg.get('DISCRETIZATION_TYPE',
'trailing')
def preprare_sampler(self,
noise,
steps=20,
scheduler_ins=None,
prediction_type='',
sigmas=None,
betas=None,
alphas=None,
callback_fn=None,
**kwargs):
output = super().preprare_sampler(noise, steps, scheduler_ins,
prediction_type, sigmas, betas,
alphas, callback_fn, **kwargs)
sigmas = output.sigmas
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5
sigmas_vp[sigmas == float('inf')] = 1.
output.add_custom_field('sigmas_vp', sigmas_vp)
return output
def step(self, sampler_output):
x_t = sampler_output.x_t
step = sampler_output.step
t = sampler_output.ts[step]
sigmas_vp = sampler_output.sigmas_vp.to(x_t.device)
alpha_init = _i(sampler_output.alphas_init, step, x_t[:1])
sigma_init = _i(sampler_output.sigmas_init, step, x_t[:1])
x = sampler_output.callback_fn(x_t, t, sigma_init, alpha_init)
noise_factor = self.eta * (sigmas_vp[step + 1]**2 /
sigmas_vp[step]**2 *
(1 - (1 - sigmas_vp[step]**2) /
(1 - sigmas_vp[step + 1]**2)))
d = (x_t - (1 - sigmas_vp[step]**2)**0.5 * x) / sigmas_vp[step]
x = (1 - sigmas_vp[step + 1] ** 2) ** 0.5 * x + \
(sigmas_vp[step + 1] ** 2 - noise_factor ** 2) ** 0.5 * d
sampler_output.x_0 = x
if sigmas_vp[step + 1] > 0:
x += noise_factor * torch.randn_like(x)
sampler_output.x_t = x
sampler_output.step += 1
sampler_output.msg = f'step {step}'
return sampler_output