# -*- coding: utf-8 -*-

import torch
from tqdm import tqdm
from typing import Iterator, Tuple, List, Union, Optional
from diffusers.schedulers import DDIMScheduler

__all__ = ["ddim_sample"]


def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int, ...]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: Union[str, torch.device] = "cuda:0",
                disable_prog: bool = True
                ) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    """Run the reverse DDIM loop, yielding the latents after every step.

    This is a generator: nothing (including the ``steps`` check) executes
    until the caller starts iterating. The last yielded ``latents`` is the
    final sample.

    Args:
        ddim_scheduler: Scheduler providing ``init_noise_sigma``,
            ``set_timesteps``, ``timesteps`` and ``step``.
        diffusion_model: Module called as ``model(latents, timesteps, cond)``
            and expected to return the noise prediction.
        shape: Per-sample latent shape; the batch dimension is prepended.
        cond: Conditioning passed unchanged to the model on every step. Its
            first dimension defines the batch size. With classifier-free
            guidance it must hold two stacked halves; the model output is
            split in two and the first half is treated as the unconditional
            prediction, the second as the conditional one.
        steps: Number of inference steps, must be positive.
        eta: Forwarded to ``ddim_scheduler.step``.
        guidance_scale: Weight applied to (conditional - unconditional).
        do_classifier_free_guidance: Whether to duplicate the latents and
            combine the two predictions with ``guidance_scale``.
        generator: Optional RNG used for the initial noise and forwarded to
            ``ddim_scheduler.step``.
        device: Device for the timesteps and the per-step timestep tensor.
            NOTE(review): the latents are created on ``cond.device``, not on
            ``device`` - callers should pass a ``device`` matching
            ``cond.device``.
        disable_prog: Disable the tqdm progress bar when True.

    Yields:
        ``(latents, t)`` - the latents returned by the scheduler step and the
        timestep that was just processed.

    Raises:
        AssertionError: If ``steps`` is not positive (on first iteration).
    """
    assert steps > 0, f"{steps} must > 0."

    # init latents: with guidance, `cond` carries two stacked halves, so the
    # number of samples actually being generated is half of its batch size
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2
    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * ddim_scheduler.init_noise_sigma

    # set timesteps
    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, and between [0, 1]
    extra_step_kwargs = {
        "eta": eta,
        "generator": generator,
    }

    # reverse
    for t in tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )
        # NOTE(review): no scheduler.scale_model_input(...) call is made on the
        # model input - confirm this is fine for the scheduler in use.

        # predict the noise residual; one timestep entry per row of the model input
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        # Call the module itself rather than `.forward` so that any
        # registered forward hooks are executed.
        noise_pred = diffusion_model(latent_model_input, timestep_tensor, cond)

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # compute the previous noisy sample x_t -> x_t-1
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t


def karra_sample():
    """Placeholder with no implementation; not exported through ``__all__``."""
    pass