import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchsde

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


if is_scipy_available():
    import scipy.stats


@dataclass
class DPMSolverSDESchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model input
            in the denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get("w0", torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        self.batched = True
        try:
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [
            torchsde.BrownianInterval(
                t0=t0,
                t1=t1,
                size=w0.shape,
                dtype=w0.dtype,
                device=w0.device,
                entropy=s,
                tol=1e-6,
                pool_size=24,
                halfway_tree=True,
            )
            for s in seed
        ]

    @staticmethod
    def sort(a, b):
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item,
            each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
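

# A minimal usage sketch of the noise sampler above (illustrative only, not part of the
# scheduler's public API; the shapes and sigma values are assumptions). The sampler returns
# Brownian-bridge noise for the interval (sigma, sigma_next), normalized to unit variance,
# so the same seed and interval reproduce the same noise:
#
#   x = torch.zeros(2, 4, 64, 64)
#   sampler = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=0)
#   noise = sampler(torch.tensor(10.0), torch.tensor(5.0))  # same shape/dtype/device as x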
					
						
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
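
# A hedged example of the helper above (values are illustrative): with the default
# "cosine" transform, each beta is 1 - alpha_bar(t2) / alpha_bar(t1) over consecutive
# fractions of the schedule, clipped at max_beta, and the cumulative product of
# (1 - beta) approximately recovers the discretized alpha_bar curve.
#
#   betas = betas_for_alpha_bar(1000)                    # tensor of shape (1000,)
#   alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)   # ~ alpha_bar(i / 1000)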
					
						
class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
    """
    DPMSolverSDEScheduler implements the stochastic sampler from the [Elucidating the Design Space of Diffusion-Based
    Generative Models](https://huggingface.co/papers/2206.00364) paper.

    This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
    generic methods the library implements for all schedulers, such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of the [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        noise_sampler_seed (`int`, *optional*, defaults to `None`):
            The random seed to use for the noise sampler. If `None`, a random seed is generated.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        noise_sampler_seed: Optional[int] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # set all schedule-dependent values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        self.use_karras_sigmas = use_karras_sigmas
        self.noise_sampler = None
        self.noise_sampler_seed = noise_sampler_seed
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # if the timestep matches more than one entry in the interleaved schedule,
        # take the second occurrence
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index of the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self,
        sample: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
        sample = sample / ((sigma_input**2 + 1) ** 0.5)
        return sample
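
    # Note (a sketch, not from the original source): the scaling above is
    # x_in = sample / sqrt(sigma**2 + 1), where sigma is the current sigma for the
    # first-order half of a step and the stored midpoint sigma for the second-order half.
    # This maps the sigma-parameterized sample to the input scale expected by models
    # trained with a variance-preserving noise schedule. Standalone form:
    #
    #   scaled = sample / ((sigma**2 + 1) ** 0.5)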
					
						
    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            num_train_timesteps (`int`, *optional*):
                The number of diffusion steps used to train the model. If `None`, the value from the scheduler
                configuration is used.
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        # "linspace", "leading", and "trailing" correspond to the annotations in Table 2 of
        # https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by the ratio
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by the ratio
            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)

        # append the final sigma (0.0) and duplicate the interior sigmas for the second-order halves
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # interleave the midpoint timesteps at the odd positions
        timesteps = torch.from_numpy(timesteps)
        second_order_timesteps = torch.from_numpy(second_order_timesteps)
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        timesteps[1::2] = second_order_timesteps

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = timesteps.to(device, dtype=torch.float32)
        else:
            self.timesteps = timesteps.to(device=device)

        # empty first-order state
        self.sample = None
        self.mid_point_sigma = None

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.noise_sampler = None
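
    # Sketch of the interleaving above, with assumed toy values (not from the source): for
    # first-order sigmas [3, 2, 1] plus the appended 0, duplicating the interior values pairs
    # each denoising step into a first-order and a second-order half, matching `order = 2`:
    #
    #   s = torch.tensor([3.0, 2.0, 1.0, 0.0])
    #   torch.cat([s[:1], s[1:-1].repeat_interleave(2), s[-1:]])
    #   # tensor([3., 2., 2., 1., 1., 0.])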
					
						
    def _second_order_timesteps(self, sigmas, log_sigmas):
        def sigma_fn(_t):
            return np.exp(-_t)

        def t_fn(_sigma):
            return -np.log(_sigma)

        midpoint_ratio = 0.5
        t = t_fn(sigmas)
        delta_time = np.diff(t)
        t_proposed = t[:-1] + delta_time * midpoint_ratio
        sig_proposed = sigma_fn(t_proposed)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed])
        return timesteps

    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas
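
    # Standalone sketch of the Karras et al. (2022) spacing above (rho = 7 is the value
    # recommended in that paper; the sigma range below is assumed for illustration):
    #
    #   sigma_min, sigma_max, rho, n = 0.03, 14.6, 7.0, 10
    #   ramp = np.linspace(0, 1, n)
    #   sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    #   # decreases from sigma_max to sigma_min, with steps clustered toward small sigmas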
					
						
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # some scheduler configs define explicit sigma_min/sigma_max values; fall back to
        # the boundaries of the passed-in sigmas otherwise
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""

        # some scheduler configs define explicit sigma_min/sigma_max values; fall back to
        # the boundaries of the passed-in sigmas otherwise
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    @property
    def state_in_first_order(self):
        return self.sample is None

    def step(
        self,
        model_output: Union[torch.Tensor, np.ndarray],
        timestep: Union[float, torch.Tensor],
        sample: Union[torch.Tensor, np.ndarray],
        return_dict: bool = True,
        s_noise: float = 1.0,
    ) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor` or `np.ndarray`):
                The direct output from the learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor` or `np.ndarray`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
                tuple.
            s_noise (`float`, *optional*, defaults to 1.0):
                Scaling factor for the noise added to the sample.

        Returns:
            [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # create a noise sampler if it hasn't been created yet
        if self.noise_sampler is None:
            min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
            self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)

        # define functions to convert between sigma and the log-sigma time t
        def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
            return _t.neg().exp()

        def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
            return _sigma.log().neg()

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # second-order half of the step
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        # midpoint between the current and next timestep in log-sigma space
        midpoint_ratio = 0.5
        t, t_next = t_fn(sigma), t_fn(sigma_next)
        delta_time = t_next - t
        t_proposed = t + delta_time * midpoint_ratio

        # 1. compute the predicted original sample (x_0) from the model output
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        if sigma_next == 0:
            # final step: take a plain Euler step to sigma = 0
            derivative = (sample - pred_original_sample) / sigma
            dt = sigma_next - sigma
            prev_sample = sample + derivative * dt
        else:
            if self.state_in_first_order:
                t_next = t_proposed
            else:
                # use the sample stored during the first-order half
                sample = self.sample

            sigma_from = sigma_fn(t)
            sigma_to = sigma_fn(t_next)
            sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
            sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
            ancestral_t = t_fn(sigma_down)
            prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
                t - ancestral_t
            ).expm1() * pred_original_sample
            prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up

        if self.state_in_first_order:
            # store the sample and midpoint sigma for the second-order half
            self.sample = sample
            self.mid_point_sigma = sigma_fn(t_next)
        else:
            # reset so the next call starts a new first-order step
            self.sample = None
            self.mid_point_sigma = None

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when the scheduler is used for training or the pipeline does not set it
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after the first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add_noise is called before the first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
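

# End-to-end usage sketch (assumptions: `model` is a callable returning epsilon predictions
# and is not defined in this file; shapes and step counts are illustrative). This mirrors how
# diffusers pipelines typically drive a scheduler, but it is not the canonical API of any
# particular pipeline:
#
#   scheduler = DPMSolverSDEScheduler(use_karras_sigmas=True, noise_sampler_seed=0)
#   scheduler.set_timesteps(num_inference_steps=25, device="cpu")
#   sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
#   for t in scheduler.timesteps:
#       model_input = scheduler.scale_model_input(sample, t)
#       noise_pred = model(model_input, t)  # hypothetical model call
#       sample = scheduler.step(noise_pred, t, sample).prev_sample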
					
						