# modified from https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L23
import torch
import torch.nn as nn


class ModelSamplingDiscreteFlow(nn.Module):
    """Helper for sampler scheduling (i.e. timestep/sigma calculations) for Discrete Flow models."""

    def __init__(self, num_train_timesteps=1000, shift=1.0, **kwargs):
        super().__init__()
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.num_inference_steps = None  # set by `set_timesteps`
        self.timesteps = None  # set by `set_timesteps`
        ts = self.to_sigma(torch.arange(1, num_train_timesteps + 1, 1))  # sigmas increasing from ~1/num_train_timesteps to 1
        self.register_buffer("sigmas", ts)

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def to_timestep(self, sigma):
        # Note: this is the exact inverse of `to_sigma` only when shift == 1.0.
        return sigma * self.num_train_timesteps

    def to_sigma(self, timestep: torch.Tensor):
        timestep = timestep / self.num_train_timesteps
        if self.shift == 1.0:
            return timestep
        # Timestep shifting of the sigma schedule, as in SD3.
        return self.shift * timestep / (1 + (self.shift - 1) * timestep)

    def uniform_sample_t(self, batch_size, device):
        # Despite the name, returns sigmas sampled uniformly from [sigma_min, sigma_max].
        ts = (self.sigma_max - self.sigma_min) * torch.rand(batch_size, device=device) + self.sigma_min
        return ts

    def calculate_denoised(self, sigma, model_output, model_input):
        # The model output is the vector field v = dx = (x_1 - x_0),
        # so the denoised estimate is x_0 = x_t - v * sigma.
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image):
        # Forward process: linear interpolation between data and noise.
        return sigma * noise + (1.0 - sigma) * latent_image

    def add_noise(self, sample, noise=None, timesteps=None):
        # sample: (B, L, D)
        if timesteps is None:
            # Sample a random sigma (and matching timestep) per batch element.
            batch_size = sample.shape[0]
            sigmas = self.uniform_sample_t(batch_size, device=sample.device).to(dtype=sample.dtype)  # (B,)
            timesteps = self.to_timestep(sigmas)
        else:
            timesteps = timesteps.to(device=sample.device, dtype=sample.dtype)
            sigmas = self.to_sigma(timesteps)
        sigmas = sigmas.view(-1, 1, 1)  # (B, 1, 1)
        if noise is None:
            noise = torch.randn_like(sample)
        noisy_samples = sigmas * noise + (1.0 - sigmas) * sample
        # Returns the noisy samples, the noise, the flow-matching target
        # (noise - sample), and the timesteps used.
        return noisy_samples, noise, noise - sample, timesteps

    def set_timesteps(self, num_inference_steps, device=None):
        if num_inference_steps > self.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps}, as the model trained with this scheduler can only handle"
                f" at most {self.num_train_timesteps} timesteps."
            )
        self.num_inference_steps = num_inference_steps
        # Timesteps run from high noise (sigma_max) down to low noise (sigma_min).
        start = self.to_timestep(self.sigma_max).item()
        end = self.to_timestep(self.sigma_min).item()
        self.timesteps = torch.linspace(start, end, num_inference_steps).to(device)

    def append_dims(self, x, target_dims):
        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
        dims_to_append = target_dims - x.ndim
        if dims_to_append < 0:
            raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
        return x[(...,) + (None,) * dims_to_append]

    def to_d(self, x, sigma, denoised):
        """Converts a denoiser output to a Karras ODE derivative."""
        return (x - denoised) / self.append_dims(sigma, x.ndim)

    @torch.no_grad()
    def step(self, model_output, timestep, sample, method="euler", **kwargs):
        """
        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model, i.e. the direction (noise - x_0).
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process, x_t.
            method (`str`):
                ODE solver, `euler` or `dpmpp_2m`.

        Returns:
            `tuple`: the previous sample x_{t-1} and the predicted original sample x_0.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )
        sigma = self.to_sigma(timestep)
        # Uniform step in sigma space, clamped so the final sigma is not negative.
        prev_sigma = sigma - (self.sigma_max - self.sigma_min) / (self.num_inference_steps - 1)
        prev_sigma = 0.0 if prev_sigma < 0.0 else prev_sigma
        if method == "euler":
            # Euler step (Algorithm 2 from Karras et al., 2022).
            dt = prev_sigma - sigma
            prev_sample = sample + model_output * dt
        elif method == "dpmpp_2m":
            # DPM-Solver++(2M).
            raise NotImplementedError
        else:
            raise ValueError(f"Unsupported ODE solver: {method}, only supports `euler` or `dpmpp_2m`")
        pred_original_sample = sample - model_output * sigma
        return (prev_sample, pred_original_sample)

    def get_pred_original_sample(self, model_output, timestep, sample):
        # x_0 = x_t - v * sigma, with sigma broadcast over (B, L, D).
        sigma = self.to_sigma(timestep).view(-1, 1, 1)
        pred_original_sample = sample - model_output * sigma
        return pred_original_sample