DEIS
Fast Sampling of Diffusion Models with Exponential Integrator.
Overview
The original paper can be found at https://arxiv.org/abs/2204.13902 and the original implementation at https://github.com/qsh-zh/deis.
DEISMultistepScheduler
class diffusers.DEISMultistepScheduler
< source >( num_train_timesteps: int = 1000 beta_start: float = 0.0001 beta_end: float = 0.02 beta_schedule: str = 'linear' trained_betas: typing.Optional[numpy.ndarray] = None solver_order: int = 2 prediction_type: str = 'epsilon' thresholding: bool = False dynamic_thresholding_ratio: float = 0.995 sample_max_value: float = 1.0 algorithm_type: str = 'deis' solver_type: str = 'logrho' lower_order_final: bool = True )
Parameters
- num_train_timesteps (int) — number of diffusion steps used to train the model.
- beta_start (float) — the starting beta value of inference.
- beta_end (float) — the final beta value.
- beta_schedule (str) — the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from linear, scaled_linear, or squaredcos_cap_v2.
- trained_betas (np.ndarray, optional) — option to pass an array of betas directly to the constructor, bypassing beta_start, beta_end, etc.
- solver_order (int, default 2) — the order of DEIS; can be 1, 2, or 3. We recommend solver_order=2 for guided sampling and solver_order=3 for unconditional sampling.
- prediction_type (str, default epsilon) — indicates whether the model predicts the noise (epsilon) or the data / x0. One of epsilon, sample, or v-prediction.
- thresholding (bool, default False) — whether to use the “dynamic thresholding” method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models such as Stable Diffusion.
- dynamic_thresholding_ratio (float, default 0.995) — the ratio for the dynamic thresholding method. The default of 0.995 is the same as Imagen (https://arxiv.org/abs/2205.11487).
- sample_max_value (float, default 1.0) — the threshold value for dynamic thresholding. Only used when thresholding=True.
- algorithm_type (str, default deis) — the algorithm type for the solver. Currently only multistep DEIS is supported; other variants of DEIS will be added in the future.
- lower_order_final (bool, default True) — whether to use lower-order solvers in the final steps. Only applies when using fewer than 15 inference steps. We empirically find this trick can stabilize the sampling of DEIS for fewer than 15 steps, especially for 10 or fewer.
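For illustration, here is a minimal sketch of constructing the scheduler with a few of these options set explicitly (the values shown simply restate the defaults):

```python
from diffusers import DEISMultistepScheduler

# Construct the scheduler with a handful of the options described above;
# the values below restate the defaults and can be changed as needed.
scheduler = DEISMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="linear",
    solver_order=2,
    prediction_type="epsilon",
)
```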
DEIS (https://arxiv.org/abs/2204.13902) is a fast, high-order solver for diffusion ODEs. We slightly modify the polynomial fitting formula to work in log-rho space instead of the original linear-t space used in the DEIS paper. This modification enjoys closed-form coefficients for the exponential multistep update instead of relying on a numerical solver. More variants of DEIS can be found at https://github.com/qsh-zh/deis.
Currently, we support the log-rho multistep DEIS. We recommend solver_order=2 or solver_order=3; solver_order=1 reduces to DDIM.
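As a usage sketch, the scheduler can be swapped into an existing pipeline via from_config; the checkpoint id, prompt, and step count below are only examples and any diffusers pipeline works the same way:

```python
import torch
from diffusers import DiffusionPipeline, DEISMultistepScheduler

# The checkpoint id is an example; substitute the pipeline you actually use.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Reuse the existing scheduler config and switch to DEIS with solver_order=2,
# the recommended setting for guided sampling.
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config, solver_order=2)
pipe = pipe.to("cuda")

# DEIS is a fast solver, so relatively few inference steps are often enough.
image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
```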
We also support the “dynamic thresholding” method from Imagen (https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set thresholding=True to use dynamic thresholding.
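For a pixel-space model, enabling it might look like the following sketch (the ratio and max value shown are just the defaults):

```python
from diffusers import DEISMultistepScheduler

# Dynamic thresholding is intended for pixel-space models only; do not enable
# it for latent-space models such as Stable Diffusion.
scheduler = DEISMultistepScheduler(
    thresholding=True,
    dynamic_thresholding_ratio=0.995,
    sample_max_value=1.0,
)
```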
~ConfigMixin takes care of storing all config attributes that are passed in the scheduler’s __init__
function, such as num_train_timesteps
. They can be accessed via scheduler.config.num_train_timesteps
.
SchedulerMixin provides general loading and saving functionality via the SchedulerMixin.save_pretrained() and
from_pretrained() functions.
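A short sketch of this behavior (the directory name is an arbitrary choice for the example):

```python
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler(num_train_timesteps=1000)

# Arguments passed to __init__ are stored on the config.
print(scheduler.config.num_train_timesteps)  # 1000

# Save and reload the scheduler configuration; the directory name is arbitrary.
scheduler.save_pretrained("./deis-scheduler")
scheduler = DEISMultistepScheduler.from_pretrained("./deis-scheduler")
```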
convert_model_output
< source >(
model_output: FloatTensor
timestep: int
sample: FloatTensor
)
→
torch.FloatTensor
Parameters
- model_output (torch.FloatTensor) — direct output from learned diffusion model.
- timestep (int) — current discrete timestep in the diffusion chain.
- sample (torch.FloatTensor) — current instance of sample being created by diffusion process.
Returns
torch.FloatTensor
the converted model output.
Convert the model output to the corresponding type that the algorithm DEIS needs.
deis_first_order_update
< source >(
model_output: FloatTensor
timestep: int
prev_timestep: int
sample: FloatTensor
)
→
torch.FloatTensor
Parameters
- model_output (torch.FloatTensor) — direct output from learned diffusion model.
- timestep (int) — current discrete timestep in the diffusion chain.
- prev_timestep (int) — previous discrete timestep in the diffusion chain.
- sample (torch.FloatTensor) — current instance of sample being created by diffusion process.
Returns
torch.FloatTensor
the sample tensor at the previous timestep.
One step for the first-order DEIS (equivalent to DDIM).
multistep_deis_second_order_update
< source >(
model_output_list: typing.List[torch.FloatTensor]
timestep_list: typing.List[int]
prev_timestep: int
sample: FloatTensor
)
→
torch.FloatTensor
Parameters
- model_output_list (List[torch.FloatTensor]) — direct outputs from learned diffusion model at current and latter timesteps.
- timestep_list (List[int]) — current and latter discrete timesteps in the diffusion chain.
- prev_timestep (int) — previous discrete timestep in the diffusion chain.
- sample (torch.FloatTensor) — current instance of sample being created by diffusion process.
Returns
torch.FloatTensor
the sample tensor at the previous timestep.
One step for the second-order multistep DEIS.
multistep_deis_third_order_update
< source >(
model_output_list: typing.List[torch.FloatTensor]
timestep_list: typing.List[int]
prev_timestep: int
sample: FloatTensor
)
→
torch.FloatTensor
Parameters
- model_output_list (List[torch.FloatTensor]) — direct outputs from learned diffusion model at current and latter timesteps.
- timestep_list (List[int]) — current and latter discrete timesteps in the diffusion chain.
- prev_timestep (int) — previous discrete timestep in the diffusion chain.
- sample (torch.FloatTensor) — current instance of sample being created by diffusion process.
Returns
torch.FloatTensor
the sample tensor at the previous timestep.
One step for the third-order multistep DEIS.
scale_model_input
< source >(
sample: FloatTensor
*args
**kwargs
)
→
torch.FloatTensor
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep.
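In a scheduler-agnostic sampling loop you would call it on the sample before every model forward pass; for DEIS this is expected to return the sample unchanged. A minimal sketch with a placeholder sample:

```python
import torch
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler()
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 64, 64)  # placeholder sample
t = scheduler.timesteps[0]

# DEIS needs no timestep-dependent input scaling, so this is expected to be a
# pass-through; calling it keeps the loop compatible with other schedulers.
model_input = scheduler.scale_model_input(sample, t)
```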
set_timesteps
< source >( num_inference_steps: int device: typing.Union[str, torch.device] = None )
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
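A small sketch of preparing the scheduler for inference (the step count and device are arbitrary choices for the example):

```python
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler()
# 25 inference steps on CPU, chosen only for illustration.
scheduler.set_timesteps(num_inference_steps=25, device="cpu")
print(scheduler.timesteps)  # discrete timesteps in descending order
```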
step
< source >(
model_output: FloatTensor
timestep: int
sample: FloatTensor
return_dict: bool = True
)
→
~scheduling_utils.SchedulerOutput
or tuple
Parameters
- model_output (torch.FloatTensor) — direct output from learned diffusion model.
- timestep (int) — current discrete timestep in the diffusion chain.
- sample (torch.FloatTensor) — current instance of sample being created by diffusion process.
- return_dict (bool) — whether to return a SchedulerOutput class instead of a plain tuple.
Returns
~scheduling_utils.SchedulerOutput or tuple
~scheduling_utils.SchedulerOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is the sample tensor.
Step function propagating the sample with the multistep DEIS.
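A hedged sketch of a manual denoising loop built around step(); the unet, its conditioning, and the latent shape are placeholders standing in for your own model, and only the scheduler calls are the documented API:

```python
import torch
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=25)

# `unet` and `encoder_hidden_states` are placeholders for your own denoising
# model and its conditioning.
sample = torch.randn(1, unet.config.in_channels, 64, 64) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    with torch.no_grad():
        noise_pred = unet(model_input, t, encoder_hidden_states=encoder_hidden_states).sample
    # One multistep DEIS update; prev_sample is the sample at the previous timestep.
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```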