# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): a float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs.
    """

    def alpha_bar(time_step):
        # squared-cosine cumulative alpha; the 0.008 offset keeps alpha_bar(0) slightly below 1
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta to avoid a singular final step
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a
    corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is
    model-agnostic by design, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can
    also be applied to both noise prediction models and data prediction models. The corrector UniC can also be applied
    after any off-the-shelf solver to increase the order of accuracy.

    For more details, see the original paper: https://arxiv.org/abs/2302.04867

    Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We
    recommend using `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.

    We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
    diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note
    that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        solver_order (`int`, default `2`):
            the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of
            accuracy is `solver_order + 1` due to the UniC. We recommend using `solver_order=2` for guided sampling,
            and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the
            dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models
            (such as stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487).
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, default `True`):
            whether to use the updating algorithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details
        solver_type (`str`, default `bh2`):
            the solver type of UniPC, `bh1` or `bh2`; the legacy names `midpoint`, `heun` and `logrho` are mapped to
            `bh2`. We recommend using `bh1` for unconditional sampling when steps < 10, and `bh2` otherwise.
        lower_order_final (`bool`, default `True`):
            whether to use lower-order solvers in the final steps. When enabled, the predictor order is capped at the
            number of remaining steps. We empirically find this trick can stabilize the sampling of DPM-Solver for
            steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            step indices whose predicted sample is not corrected: at step `k` the corrector runs only if `k - 1` is not
            in this list. For large guidance scale, the misalignment between the `epsilon_theta(x_t, c)` and
            `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated by disabling the corrector
            at the first few steps (e.g., disable_corrector=[0])
        solver_p (`SchedulerMixin`, default `None`):
            can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
             noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
             of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
        timestep_spacing (`str`, default `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
        steps_offset (`int`, default `0`):
            an offset added to the inference timesteps. It is only applied when `timestep_spacing="leading"`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: List[int] = [],
        solver_p: SchedulerMixin = None,
        use_karras_sigmas: Optional[bool] = False,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        # half log-SNR: lambda_t = log(alpha_t / sigma_t)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        # store a private copy so the shared default list (and the caller's list) is never aliased by the instance
        self.disable_corrector = list(disable_corrector) if disable_corrector is not None else []
        self.solver_p = solver_p
        self.last_sample = None

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Also resets the multistep history (`model_outputs`, `timestep_list`, `lower_order_nums`, `last_sample`) so a
        new sampling run starts from a first-order warmup.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # training sigmas; they increase with the timestep index because alphas_cumprod decreases
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            # the Karras sigmas are ascending, so flip to get descending (noisy -> clean) timesteps
            timesteps = np.flip(timesteps).copy().astype(np.int64)

        self.sigmas = torch.from_numpy(sigmas)

        # when num_inference_steps == num_train_timesteps, we can end up with
        # duplicates in timesteps.
        _, unique_indices = np.unique(timesteps, return_index=True)
        timesteps = timesteps[np.sort(unique_indices)]

        self.timesteps = torch.from_numpy(timesteps).to(device)

        self.num_inference_steps = len(timesteps)

        # reset the multistep history; both buffers are shifted together in `step`
        self.model_outputs = [None] * self.config.solver_order
        self.timestep_list = [None] * self.config.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)

    def _sigma_to_t(self, sigma, log_sigmas):
        """
        Map a noise level `sigma` to a (fractional) training timestep by linear interpolation in log-sigma space.

        Args:
            sigma (`float` or `np.ndarray`): the noise level(s) to convert.
            log_sigmas (`np.ndarray`): log of the training sigmas, indexed by training timestep and increasing.

        Returns:
            `np.ndarray`: the interpolated timestep(s), with the same shape as `sigma`; values outside the table are
            clamped to its endpoints.
        """
        log_sigma = np.log(sigma)

        # signed distance of the query to every table entry
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # index of the last table entry <= log_sigma (kept one short of the end so `high_idx` stays in range)
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolation weight between the two bracketing entries
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    def _convert_to_karras(self, in_sigmas, num_inference_steps) -> np.ndarray:
        """
        Construct the noise schedule of Karras et al. (2022), https://arxiv.org/abs/2206.00364 (Eq. 5).

        The returned `num_inference_steps` sigmas are spaced uniformly in `sigma ** (1 / rho)` and run from
        `in_sigmas[0]` to `in_sigmas[-1]`, i.e. the ordering of the input's endpoints is preserved.

        Args:
            in_sigmas (`np.ndarray` or `torch.Tensor`): the reference sigmas; only the first and last entries are used.
            num_inference_steps (`int`): the number of sigmas to produce.

        Returns:
            `np.ndarray`: the Karras sigmas.
        """
        sigma_start = float(in_sigmas[0])
        sigma_end = float(in_sigmas[-1])

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        start_inv_rho = sigma_start ** (1 / rho)
        end_inv_rho = sigma_end ** (1 / rho)
        sigmas = (start_inv_rho + ramp * (end_inv_rho - start_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, height, width = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * height * width)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, height, width)
        sample = sample.to(dtype)

        return sample

    def convert_model_output(
        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
    ) -> torch.FloatTensor:
        r"""
        Convert the model output to the corresponding type that the algorithm PC needs.

        With `predict_x0=True` the output is converted to a data (x0) prediction, optionally dynamically thresholded;
        otherwise it is converted to a noise (epsilon) prediction.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.

        Returns:
            `torch.FloatTensor`: the converted model output.

        Raises:
            ValueError: if `prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
        """
        if self.predict_x0:
            if self.config.prediction_type == "epsilon":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.config.prediction_type == "epsilon":
                return model_output
            elif self.config.prediction_type == "sample":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = (sample - alpha_t * model_output) / sigma_t
                return epsilon
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = alpha_t * model_output + sigma_t * sample
                return epsilon
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the UniPCMultistepScheduler."
                )

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        """
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.

        Args:
            model_output (`torch.FloatTensor`):
                direct outputs from learned diffusion model at the current timestep.
            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            order (`int`): the order of UniP at this step, also the p in UniPC-p.

        Returns:
            `torch.FloatTensor`: the sample tensor at the previous timestep.
        """
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = self.timestep_list[-1], prev_timestep
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = sample.device

        # ratios r_k of the older step sizes to h, and the scaled finite differences D1_k of the model outputs
        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Vandermonde-like system R @ rho = b defining the UniP coefficients
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K, C, H, W)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
            # einsum needs both operands in one dtype: the solve runs in the schedule's dtype (float32), while D1s
            # follows the model-output dtype (e.g. float16 pipelines)
            rhos_p = rhos_p.to(device=device, dtype=D1s.dtype)
        else:
            D1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        """
        One step for the UniC (B(h) version).

        Args:
            this_model_output (`torch.FloatTensor`): the model outputs at `x_t`
            this_timestep (`int`): the current timestep `t`
            last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}`
            this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}`
            order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy
                should be order + 1

        Returns:
            `torch.FloatTensor`: the corrected sample tensor at the current timestep.
        """
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = timestep_list[-1], this_timestep
        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = this_sample.device

        # ratios r_k of the older step sizes to h, and the scaled finite differences D1_k of the model outputs
        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Vandermonde-like system R @ rho = b defining the UniC coefficients
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K, C, H, W)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        if D1s is not None:
            # einsum needs both operands in one dtype: the solve runs in the schedule's dtype (float32), while D1s
            # follows the model-output dtype (e.g. float16 pipelines)
            rhos_c = rhos_c.to(device=device, dtype=D1s.dtype)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the multistep UniPC.

        The sample is first corrected with UniC (using the new model output), then advanced with UniP.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called.
        """

        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero()
        if len(step_index) == 0:
            # unknown timestep: treat it as the final step
            step_index = len(self.timesteps) - 1
        else:
            step_index = step_index.item()

        # the corrector refines the previous predictor output, so it needs a stored last_sample
        use_corrector = (
            step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
        )

        model_output_convert = self.convert_model_output(model_output, timestep, sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                this_timestep=timestep,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # now prepare to run the predictor
        prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]

        # shift the history window left by one and append the newest entry
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        if self.config.lower_order_final:
            this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
        else:
            this_order = self.config.solver_order

        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            prev_timestep=prev_timestep,
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. UniPC needs no scaling, so the sample is returned unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps