# ddim-dsc / cond_ddim_pipeline.py
# Last commit 87d6e87 (verified) by lschmidt:
#   "Rename cond-ddim-pipeline.py to cond_ddim_pipeline.py"
from typing import List, Optional, Tuple, Union
import torch
import inspect
from diffusers import DDIMScheduler, DiffusionPipeline, ImagePipelineOutput
class CondDDIMPipeline(DiffusionPipeline):
    r"""
    Conditional DDIM pipeline: denoises random noise into images while concatenating a conditioning
    image to the UNet input (along the channel dimension) at every denoising step.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` that predicts the noise. Its input is the noisy sample (`out_channels` channels)
            concatenated with the conditioning image, so it must accept `out_channels + <image channels>`
            input channels.
        scheduler ([`SchedulerMixin`]):
            Any scheduler; only its config is used, and it is converted to a [`DDIMScheduler`] on construction.
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()
        # Always sample with DDIM, regardless of which scheduler the checkpoint was saved with.
        scheduler = DDIMScheduler.from_config(scheduler.config)
        self.register_modules(unet=unet, scheduler=scheduler)

    @staticmethod
    def _randn(shape, generator, device, dtype):
        """Draw standard-normal noise of `shape` on `device`, honouring one generator or a per-sample list."""
        if isinstance(generator, list):
            if len(generator) != shape[0]:
                raise ValueError(
                    f"Got a list of {len(generator)} generators, but the total batch size is {shape[0]}."
                )
            sample_shape = (1, *shape[1:])
            return torch.cat(
                [CondDDIMPipeline._randn(sample_shape, g, device, dtype) for g in generator], dim=0
            )
        # torch.randn requires the output device to match the generator's device, so sample where the
        # generator lives and move the result to the execution device afterwards.
        gen_device = generator.device if generator is not None else device
        return torch.randn(shape, generator=generator, device=gen_device, dtype=dtype).to(device)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        image: torch.Tensor = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_images_per_cond: Optional[int] = 1,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                Unused; kept for backward compatibility. The batch size is taken from `image.shape[0]`.
            image (`torch.Tensor`):
                The LR image(s) to condition on, shaped `(batch, channels, height, width)`. Required.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) (or one per
                generated sample) to make generation deterministic. If `None`, a freshly constructed generator
                on the execution device is used (torch's default initial seed).
            num_images_per_cond (`int`, *optional*, defaults to 1):
                Number of samples per conditioning image. The conditioning batch is tiled, so outputs are
                ordered `[cond_0 .. cond_{B-1}, cond_0 .. cond_{B-1}, ...]`.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only
                passed to schedulers whose `step` accepts it. `0` corresponds to DDIM and `1` to DDPM.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
                If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed
                downstream to the scheduler.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"` returns a list of `PIL.Image`. Any other value returns the raw samples as a numpy
                array of shape `(batch * num_images_per_cond, out_channels, height, width)`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] if `return_dict` is `True`, otherwise a one-element tuple
            containing the images.

        Raises:
            ValueError: if `image` is missing, `num_images_per_cond` is < 1, or a generator list does not
                match the total batch size.
        """
        if image is None:
            raise ValueError("`image` (the conditioning tensor) is required.")
        if num_images_per_cond is None:
            num_images_per_cond = 1
        if num_images_per_cond < 1:
            raise ValueError(f"`num_images_per_cond` must be >= 1, got {num_images_per_cond}.")

        # Use the execution device (not `self.device`) so CPU offloading puts everything on the GPU, and a
        # single dtype so the latents and the conditioning image can be concatenated.
        device = self._execution_device
        dtype = self.unet.dtype

        bs, _, height, width = image.shape

        # Respect a caller-supplied generator; only fall back to a fresh one when none was given.
        if generator is None:
            generator = torch.Generator(device=device)

        latents_shape = (bs * num_images_per_cond, self.unet.config.out_channels, height, width)
        latents = self._randn(latents_shape, generator, device, dtype)

        # Tile the conditioning batch once per requested sample (same ordering as torch.cat([image] * n)).
        image = image.repeat(num_images_per_cond, 1, 1, 1).to(device=device, dtype=dtype)

        self.scheduler.set_timesteps(num_inference_steps)
        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * self.scheduler.init_noise_sigma

        # Only pass the keyword arguments the current scheduler's `step` actually accepts.
        step_params = set(inspect.signature(self.scheduler.step).parameters)
        step_kwargs = {}
        if "eta" in step_params:
            step_kwargs["eta"] = eta
        if "generator" in step_params:
            step_kwargs["generator"] = generator
        if use_clipped_model_output is not None:
            step_kwargs["use_clipped_model_output"] = use_clipped_model_output

        for t in self.progress_bar(self.scheduler.timesteps):
            # Scale only the noisy sample; the conditioning image is not part of the diffusion state.
            model_input = torch.cat([self.scheduler.scale_model_input(latents, t), image], dim=1)
            noise_pred = self.unet(model_input, t).sample
            # x_t -> x_{t-1}
            latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample

        if output_type == "pil":
            # numpy_to_pil expects channels-last arrays with values in [0, 1].
            # NOTE(review): assumes the model was trained on data scaled to [-1, 1] (the diffusers
            # DDPM/DDIM pipeline convention) — confirm for this checkpoint.
            arr = (latents / 2 + 0.5).clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy()
            image = self.numpy_to_pil(arr)
        else:
            # Raw samples in model space, channels-first, unscaled.
            image = latents.cpu().numpy()

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)