EscherNet / 6DoF /diffusers /schedulers /scheduling_unipc_multistep.py
kxhit
update
5f093a6
raw
history blame
30.5 kB
# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a
corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is
by desinged model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can
also be applied to both noise prediction model and data prediction model. The corrector UniC can be also applied
after any off-the-shelf solvers to increase the order of accuracy.
For more details, see the original paper: https://arxiv.org/abs/2302.04867
Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We recommend
to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note
that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
solver_order (`int`, default `2`):
the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of
accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided sampling,
and `solver_order=3` for unconditional sampling.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the
dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models
(such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487).
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, default `True`):
whether to use the updating algrithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details
solver_type (`str`, default `bh2`):
the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use `bh2`
otherwise.
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
disable_corrector (`list`, default `[]`):
decide which step to disable the corrector. For large guidance scale, the misalignment between the
`epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated
by disable the corrector at the first few steps (e.g., disable_corrector=[0])
solver_p (`SchedulerMixin`, default `None`):
can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh2",
lower_order_final: bool = True,
disable_corrector: List[int] = [],
solver_p: SchedulerMixin = None,
use_karras_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Currently we only support VP-type noise schedule
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]:
self.register_to_config(solver_type="bh2")
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
self.predict_x0 = predict_x0
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.timestep_list = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64)
self.sigmas = torch.from_numpy(sigmas)
# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps)
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
r"""
Convert the model output to the corresponding type that the algorithm PC needs.
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
Returns:
`torch.FloatTensor`: the converted model output.
"""
if self.predict_x0:
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config.prediction_type == "epsilon":
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the UniPCMultistepScheduler."
)
def multistep_uni_p_bh_update(
self,
model_output: torch.FloatTensor,
prev_timestep: int,
sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
"""
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
Args:
model_output (`torch.FloatTensor`):
direct outputs from learned diffusion model at the current timestep.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
order (`int`): the order of UniP at this step, also the p in UniPC-p.
Returns:
`torch.FloatTensor`: the sample tensor at the previous timestep.
"""
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0, t = self.timestep_list[-1], prev_timestep
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s0
device = sample.device
rks = []
D1s = []
for i in range(1, order):
si = timestep_list[-(i + 1)]
mi = model_output_list[-(i + 1)]
lambda_si = self.lambda_t[si]
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
x_t = x_t.to(x.dtype)
return x_t
def multistep_uni_c_bh_update(
self,
this_model_output: torch.FloatTensor,
this_timestep: int,
last_sample: torch.FloatTensor,
this_sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
"""
One step for the UniC (B(h) version).
Args:
this_model_output (`torch.FloatTensor`): the model outputs at `x_t`
this_timestep (`int`): the current timestep `t`
last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}`
this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}`
order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy
should be order + 1
Returns:
`torch.FloatTensor`: the corrected sample tensor at the current timestep.
"""
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0, t = timestep_list[-1], this_timestep
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s0
device = this_sample.device
rks = []
D1s = []
for i in range(1, order):
si = timestep_list[-(i + 1)]
mi = model_output_list[-(i + 1)]
lambda_si = self.lambda_t[si]
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)
else:
D1s = None
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_c = torch.linalg.solve(R, b)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
x_t = x_t.to(x.dtype)
return x_t
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the multistep UniPC.
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
use_corrector = (
step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
)
model_output_convert = self.convert_model_output(model_output, timestep, sample)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
this_timestep=timestep,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
# now prepare to run the predictor
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep
if self.config.lower_order_final:
this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
else:
this_order = self.config.solver_order
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
assert self.this_order > 0
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
prev_timestep=prev_timestep,
sample=sample,
order=self.this_order,
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps