"""SAMPLING ONLY.""" import torch import numpy as np from tqdm import tqdm from functools import partial from typing import List, Optional, Tuple, Union from ldm.util import randn_tensor from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ extract_into_tensor class LCMSampler(object): def __init__(self, model, **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.original_inference_steps = 100 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, self.ddpm_num_timesteps)[::-1].copy().astype(np.int64)) self.custom_timesteps = False self.timestep_scaling = 10.0 self.prediction_type = 'epsilon' def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) def make_schedule(self, ddim_discretize="uniform", verbose=True): # self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, # num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) # alphas_cumprod = self.model.alphas_cumprod # assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' # beta_start = 0.00085 # beta_end = 0.012 # self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, self.ddpm_num_timesteps, dtype=torch.float32) ** 2 # self.alphas = 1.0 - self.betas # self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) alphas_cumprod = self.model.alphas_cumprod assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) # # calculations for diffusion q(x_t | x_{t-1}) and others # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) # self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) # self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) # self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) # # ddim sampling parameters # ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), # ddim_timesteps=self.ddim_timesteps, # eta=ddim_eta,verbose=verbose) # self.register_buffer('ddim_sigmas', ddim_sigmas) # self.register_buffer('ddim_alphas', ddim_alphas) # self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) # self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
        # sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
        #     (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
        #             1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        # self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    def progress_bar(self, iterable=None, total=None):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        if iterable is not None:
            return tqdm(iterable, **self._progress_bar_config)
        elif total is not None:
            return tqdm(total=total, **self._progress_bar_config)
        else:
            raise ValueError("Either `total` or `iterable` has to be defined.")

    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Guidance scale values at which to generate the embedding vectors.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype:
                Data type of the generated embeddings.

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
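
    # Example (sketch): `lcm_sampling` below embeds w = guidance_scale - 1 once per
    # batch element, e.g. for guidance_scale = 5.0 and batch size 2:
    #     w = torch.tensor(4.0).repeat(2)                           # shape (2,)
    #     w_emb = self.get_guidance_scale_embedding(w, 256)         # shape (2, 256)
    # and `w_emb` is passed to the UNet (`self.model.unet`) as `w_cond`.
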
    @property
    def step_index(self):
        return self._step_index

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        original_inference_steps: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
        strength: float = 1.0,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            original_inference_steps (`int`, *optional*):
                The original number of inference steps, which will be used to generate a linearly-spaced timestep
                schedule (which is different from the standard `diffusers` implementation). We will then take
                `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that
                as our final timestep schedule. If not set, this will default to the `original_inference_steps`
                attribute.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
        """
        # 0. Check inputs
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

        # 1. Calculate the LCM original training/distillation timestep schedule.
        original_steps = (
            original_inference_steps if original_inference_steps is not None else self.original_inference_steps
        )

        if original_steps > self.ddpm_num_timesteps:
            raise ValueError(
                f"`original_steps`: {original_steps} cannot be larger than `self.ddpm_num_timesteps`:"
                f" {self.ddpm_num_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.ddpm_num_timesteps} timesteps."
            )

        # import ipdb
        # ipdb.set_trace()

        # LCM Timesteps Setting
        # The skipping step parameter k from the paper.
        k = self.ddpm_num_timesteps // original_steps
        # LCM Training/Distillation Steps Schedule
        # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
        lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1

        # 2. Calculate the LCM inference timestep schedule.
        if timesteps is not None:
            # 2.1 Handle custom timestep schedules.
            train_timesteps = set(lcm_origin_timesteps)
            non_train_timesteps = []
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`custom_timesteps` must be in descending order.")

                if timesteps[i] not in train_timesteps:
                    non_train_timesteps.append(timesteps[i])

            if timesteps[0] >= self.ddpm_num_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.ddpm_num_timesteps`:"
                    f" {self.ddpm_num_timesteps}."
                )

            # Warn if the timestep schedule does not start with self.ddpm_num_timesteps - 1
            if strength == 1.0 and timesteps[0] != self.ddpm_num_timesteps - 1:
                logger.warning(
                    f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                    f" `self.ddpm_num_timesteps - 1`: {self.ddpm_num_timesteps - 1}. You may get"
                    f" unexpected results when using this timestep schedule."
                )

            # Warn if the custom timestep schedule contains timesteps not on the original timestep schedule
            if non_train_timesteps:
                logger.warning(
                    f"The custom timestep schedule contains the following timesteps which are not on the original"
                    f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                    f" when using this timestep schedule."
                )

            # Warn if the custom timestep schedule is longer than original_steps
            if len(timesteps) > original_steps:
                logger.warning(
                    f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds"
                    f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                    f" unexpected results when using this timestep schedule."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.num_inference_steps = len(timesteps)
            self.custom_timesteps = True

            # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
            init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
            t_start = max(self.num_inference_steps - init_timestep, 0)
            timesteps = timesteps[t_start * self.order:]
            # TODO: also reset self.num_inference_steps?
        else:
            # 2.2 Create the "standard" LCM inference timestep schedule.
            if num_inference_steps > self.ddpm_num_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.ddpm_num_timesteps`:"
                    f" {self.ddpm_num_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.ddpm_num_timesteps} timesteps."
                )

            skipping_step = len(lcm_origin_timesteps) // num_inference_steps

            if skipping_step < 1:
                raise ValueError(
                    f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than"
                    f" `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps`"
                    f" to a value smaller than {int(original_steps * strength)} or increase `strength` to a value"
                    f" higher than {float(num_inference_steps / original_steps)}."
                )

            self.num_inference_steps = num_inference_steps

            if num_inference_steps > original_steps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                    f" {original_steps} because the final timestep schedule will be a subset of the"
                    f" `original_inference_steps`-sized initial timestep schedule."
                )

            # LCM Inference Steps Schedule
            lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
            # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
            inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
            inference_indices = np.floor(inference_indices).astype(np.int64)
            timesteps = lcm_origin_timesteps[inference_indices]

        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

        self._step_index = None
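
    # Worked example (sketch, assuming self.ddpm_num_timesteps == 1000, the usual
    # value for Stable-Diffusion-style latent diffusion models):
    #   original_steps = 50  ->  k = 1000 // 50 = 20
    #   lcm_origin_timesteps = [19, 39, ..., 999]              (50 entries)
    #   num_inference_steps = 4  ->  inference_indices = [0, 12, 25, 37]
    #   final schedule self.timesteps = [999, 759, 499, 259]   (descending)
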
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
    def retrieve_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
        **kwargs,
    ):
        """
        Calls `set_timesteps` and retrieves the timestep schedule afterwards. Handles custom timesteps. Any kwargs
        will be supplied to `set_timesteps`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
                must be `None`.

        Returns:
            `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule and the second
            element is the number of inference steps.
        """
        if timesteps is not None:
            self.set_timesteps(timesteps=timesteps, device=device, **kwargs)
            timesteps = self.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.set_timesteps(num_inference_steps, device=device, **kwargs)
            timesteps = self.timesteps
        return timesteps, num_inference_steps
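
    # Call chain (for orientation): `sample` builds the latent shape and delegates to
    # `lcm_sampling`, which resolves the timestep schedule via `retrieve_timesteps`
    # and repeatedly calls `step` to apply the LCM update. The `S` argument of
    # `sample` becomes `self.num_inference_steps`.
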
    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               verbose=True,
               x_T=None,
               guidance_scale=5.,
               original_inference_steps=50,
               timesteps=None,
               # this has to come in the same format as the conditioning,
               # e.g. as encoded tokens, ...
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(verbose=verbose)
        self.num_inference_steps = S

        # sampling
        if len(shape) == 3:
            C, H, W = shape
            size = (batch_size, C, H, W)
        else:
            C, T = shape
            size = (batch_size, C, T)

        samples, intermediates = self.lcm_sampling(conditioning, size,
                                                   x_T=x_T,
                                                   guidance_scale=guidance_scale,
                                                   original_inference_steps=original_inference_steps,
                                                   timesteps=timesteps
                                                   )
        return samples, intermediates

    @torch.no_grad()
    def lcm_sampling(self, cond, shape, x_T=None, guidance_scale=1., original_inference_steps=100, timesteps=None):
        device = self.model.betas.device
        timesteps, num_inference_steps = self.retrieve_timesteps(
            self.num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps
        )
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        w = torch.tensor(guidance_scale - 1).repeat(b)
        w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=256).to(
            device=device, dtype=img.dtype
        )
        # import ipdb
        # ipdb.set_trace()

        # 8. LCM MultiStep Sampling Loop:
        num_warmup_steps = len(timesteps) - num_inference_steps
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                img = img.to(cond.dtype)
                ts = torch.full((b,), t, device=device, dtype=torch.long)

                # model prediction (v-prediction, eps, x)
                model_pred = self.model.apply_model(img, ts, cond, self.model.unet, w_cond=w_embedding)

                # compute the previous noisy sample x_t -> x_t-1
                img, denoised = self.step(model_pred, t, img, return_dict=False)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps):
                    progress_bar.update()

        return denoised, img

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        index_candidates = (self.timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(index_candidates) > 1:
            step_index = index_candidates[1]
        else:
            step_index = index_candidates[0]

        self._step_index = step_index.item()

    def get_scalings_for_boundary_condition_discrete(self, timestep):
        self.sigma_data = 0.5  # Default: 0.5
        scaled_timestep = timestep * self.timestep_scaling

        c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
        c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out
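
    # The scalings above implement the consistency-model boundary condition with
    # sigma_data = 0.5 and s = timestep_scaling (10.0 by default):
    #     c_skip(t) = sigma_data^2 / ((s * t)^2 + sigma_data^2)
    #     c_out(t)  = (s * t) / sqrt((s * t)^2 + sigma_data^2)
    # so c_skip -> 1 and c_out -> 0 as t -> 0, i.e. the denoised estimate
    # `c_out * x0_pred + c_skip * sample` reduces to the identity at t = 0.
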
    @torch.no_grad()
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ):
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Kept for API compatibility with `diffusers` schedulers; a tuple is returned in either case.

        Returns:
            `Tuple[torch.FloatTensor, torch.FloatTensor]`: `(prev_sample, denoised)`, where `prev_sample` is the
            (re-noised) sample for the next timestep and `denoised` is the current estimate of the clean sample.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # 1. get previous step value
        prev_step_index = self.step_index + 1
        if prev_step_index < len(self.timesteps):
            prev_timestep = self.timesteps[prev_step_index]
        else:
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

        # 4. Compute the predicted original sample x_0 based on the model parameterization
        if self.prediction_type == "epsilon":  # noise-prediction
            predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
        elif self.prediction_type == "sample":  # x-prediction
            predicted_original_sample = model_output
        elif self.prediction_type == "v_prediction":  # v-prediction
            predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for `LCMScheduler`."
            )

        # 5. Denoise model output using boundary conditions
        denoised = c_out * predicted_original_sample + c_skip * sample

        # 6. Sample and inject noise z ~ N(0, I) for MultiStep Inference
        # Noise is not used on the final timestep of the timestep schedule.
        # This also means that noise is not used for one-step sampling.
        if self.step_index != self.num_inference_steps - 1:
            noise = torch.randn(model_output.shape, device=model_output.device)
            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
        else:
            prev_sample = denoised

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, denoised)

        return prev_sample, denoised
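
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): `model` is assumed to be a LatentDiffusion-
# style object from this codebase exposing `num_timesteps`, `betas`,
# `alphas_cumprod`, `alphas_cumprod_prev`, `device`, `apply_model` and a
# distilled LCM `unet`; `cond` is the usual conditioning tensor.
#
#   sampler = LCMSampler(model)
#   samples, intermediates = sampler.sample(
#       S=4,                           # number of LCM inference steps
#       batch_size=cond.shape[0],
#       shape=(channels, height, width),
#       conditioning=cond,
#       guidance_scale=5.0,
#       original_inference_steps=50,   # length of the distillation schedule
#   )
# ---------------------------------------------------------------------------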