Spaces:
Runtime error
Runtime error
import warnings | |
from typing import Tuple, Union | |
import torch | |
from diffusers.schedulers.scheduling_lms_discrete import \ | |
LMSDiscreteScheduler as _LMSDiscreteScheduler | |
from diffusers.schedulers.scheduling_lms_discrete import \ | |
LMSDiscreteSchedulerOutput | |
class LMSDiscreteScheduler(_LMSDiscreteScheduler):
    """Variant of the diffusers LMS discrete scheduler addressed by sigma index.

    Both `step` and `scale_model_input` take an integer position into
    `self.sigmas` instead of a timestep value.
    """

    def step(
        self,
        model_output: torch.FloatTensor,
        step_index: int,
        sample: torch.FloatTensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            step_index (`int`): index into `self.sigmas` of the current step. The sigma at this index is used as
                a divisor, so it must be non-zero.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            order (`int`): maximum number of stored derivatives combined by the linear multistep formula.
            return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

        Returns:
            `LMSDiscreteSchedulerOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
            the first element is the sample tensor.

        Raises:
            ValueError: if `self.config.prediction_type` is neither `epsilon` nor `v_prediction`.
        """
        if not self.is_scale_input_called:
            # stacklevel=2 attributes the warning to the caller of `step`, not to this file.
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example.",
                stacklevel=2,
            )

        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * \
                (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative and keep at most `order` of them as history
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
        # The usable order is limited by the position in the schedule AND by how many
        # derivatives are actually stored. Without the second cap, starting at a non-zero
        # `step_index` would compute high-order coefficients and then silently drop the
        # terms that have no matching derivative in the `zip` below.
        order = min(step_index + 1, order, len(self.derivatives))
        lms_coeffs = [
            self.get_lms_coefficient(order, step_index, curr_order)
            for curr_order in range(order)
        ]

        # 4. Compute previous sample based on the derivatives path (newest derivative first)
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        if not return_dict:
            return (prev_sample,)

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def scale_model_input(
        self,
        sample: torch.FloatTensor,
        iteration: int
    ) -> torch.FloatTensor:
        """
        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.

        Also records that scaling has happened, so `step` does not emit its warning.

        Args:
            sample (`torch.FloatTensor`): input sample
            iteration (`int`): index into `self.sigmas` of the current step

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        sample = sample / ((self.sigmas[iteration]**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample