Diffusion Models from Scratch

Sometimes it is helpful to consider the simplest possible version of something to better understand how it works. We’re going to try that in this notebook, beginning with a ‘toy’ diffusion model to see how the different pieces work, and then examining how they differ from a more complex implementation.

We will look at

  • The corruption process (adding noise to data)
  • What a UNet is, and how to implement an extremely minimal one from scratch
  • Diffusion model training
  • Sampling theory

Then we’ll compare our versions with the diffusers DDPM implementation, exploring

  • Improvements over our mini UNet
  • The DDPM noise schedule
  • Differences in training objective
  • Timestep conditioning
  • Sampling approaches

This notebook is fairly in-depth, and can safely be skipped if you’re not excited about a from-scratch deep dive!

It is also worth noting that most of the code here is for illustrative purposes, and I wouldn’t recommend directly adopting any of it for your own work (unless you’re just trying to improve on the examples shown here for learning purposes).

Setup and Imports:

>>> %pip install -q diffusers
>>> import torch
>>> import torchvision
>>> from torch import nn
>>> from torch.nn import functional as F
>>> from torch.utils.data import DataLoader
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> from matplotlib import pyplot as plt

>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> print(f"Using device: {device}")
Using device: cuda

The Data

Here we’re going to test things with a very small dataset: MNIST. If you’d like to give the model a slightly harder challenge without changing anything else, torchvision.datasets.FashionMNIST should work as a drop-in replacement.

>>> dataset = torchvision.datasets.MNIST(
...     root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
... )
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
>>> train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
>>> x, y = next(iter(train_dataloader))
>>> print("Input shape:", x.shape)
>>> print("Labels:", y)
>>> plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])

Each image is a greyscale 28px by 28px drawing of a digit, with values ranging from 0 to 1.
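
If you do want to try the slightly harder FashionMNIST variant mentioned above, a sketch of the drop-in replacement might look like this (the root directory name here is just an assumption):

dataset = torchvision.datasets.FashionMNIST(
    root="fashion_mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
)
# Everything else in the notebook stays the same.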

The Corruption Process

Pretend you haven’t read any diffusion model papers, but you know the process involves adding noise. How would you do it?

We probably want an easy way to control the amount of corruption. So what if we take in a parameter for the amount of noise to add, and then we do:

noise = torch.rand_like(x)

noisy_x = (1-amount)*x + amount*noise

If amount = 0, we get back the input without any changes. If amount gets up to 1, we get back noise with no trace of the input x. By mixing the input with noise this way, we keep the output in the same range (0 to 1).

We can implement this fairly easily (just watch the shapes so you don’t get burnt by broadcasting rules):

def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Sort shape so broadcasting works
    return x * (1 - amount) + noise * amount

And looking at the results visually to see that it works as expected:

>>> # Plotting the input data
>>> fig, axs = plt.subplots(2, 1, figsize=(12, 5))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")

>>> # Adding noise
>>> amount = torch.linspace(0, 1, x.shape[0])  # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)

>>> # Plotting the noised version
>>> axs[1].set_title("Corrupted data (-- amount increases -->)")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap="Greys")

As noise amount approaches one, our data begins to look like pure random noise. But for most noise amounts, you can guess the digit fairly well. Do you think this is optimal?

The Model

We’d like a model that takes in a noisy 28px image and outputs a prediction of the same shape. A popular choice here is an architecture called a UNet. Originally invented for segmentation tasks in medical imagery, a UNet consists of a ‘contracting path’ through which data is compressed down and an ‘expanding path’ through which it expands back up to the original dimension (similar to an autoencoder), but it also features skip connections that allow information and gradients to flow across at different levels.

Some UNets feature complex blocks at each stage, but for this toy demo we’ll build a minimal example that takes in a one-channel image and passes it through three convolutional layers on the down path (the down_layers in the diagram and code) and three on the up path, with skip connections between the down and up layers. We’ll use max pooling for downsampling and nn.Upsample for upsampling rather than relying on learnable layers like more complex UNets. Here is the rough architecture showing the number of channels in the output of each layer:

unet_diag.png

This is what that looks like in code:

class BasicUNet(nn.Module):
    """A minimal UNet implementation."""

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
            ]
        )
        self.up_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
            ]
        )
        self.act = nn.SiLU()  # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))  # Through the layer and the activation function
            if i < 2:  # For all but the third (final) down layer:
                h.append(x)  # Storing output for skip connection
                x = self.downscale(x)  # Downscale ready for the next layer

        for i, l in enumerate(self.up_layers):
            if i > 0:  # For all except the first up layer
                x = self.upscale(x)  # Upscale
                x += h.pop()  # Fetching stored output (skip connection)
            x = self.act(l(x))  # Through the layer and the activation function

        return x

We can verify that the output shape is the same as the input, as we expect:

net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape

This network has just over 300,000 parameters:

sum([p.numel() for p in net.parameters()])

You can explore changing the number of channels in each layer or swapping in different architectures if you want.

Training the network

So what should the model do, exactly? Again, there are various takes on this but for this demo let’s pick a simple framing: given a corrupted input noisy_x the model should output its best guess for what the original x looks like. We will compare this to the actual value via the mean squared error.

We can now have a go at training the network.

  • Get a batch of data
  • Corrupt it by random amounts
  • Feed it through the model
  • Compare the model predictions with the clean images to calculate our loss
  • Update the model’s parameters accordingly.

Feel free to modify this and see if you can get it working better!

>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

>>> # How many runs through the data should we do?
>>> n_epochs = 3

>>> # Create the network
>>> net = BasicUNet()
>>> net.to(device)

>>> # Our loss function
>>> loss_fn = nn.MSELoss()

>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)

>>> # Keeping a record of the losses for later viewing
>>> losses = []

>>> # The training loop
>>> for epoch in range(n_epochs):

...     for x, y in train_dataloader:

...         # Get some data and prepare the corrupted version
...         x = x.to(device)  # Data on the GPU
...         noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
...         noisy_x = corrupt(x, noise_amount)  # Create our noisy x

...         # Get the model prediction
...         pred = net(noisy_x)

...         # Calculate the loss
...         loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

...         # Backprop and update the params:
...         opt.zero_grad()
...         loss.backward()
...         opt.step()

...         # Store the loss for later
...         losses.append(loss.item())

...     # Print out the average of the loss values for this epoch:
...     avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
...     print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")

>>> # View the loss curve
>>> plt.plot(losses)
>>> plt.ylim(0, 0.1)
Finished epoch 0. Average loss for this epoch: 0.026736
Finished epoch 1. Average loss for this epoch: 0.020692
Finished epoch 2. Average loss for this epoch: 0.018887

We can see what the model’s predictions look like by grabbing a batch of data, corrupting it by different amounts and then viewing the outputs:

>>> # @markdown Visualizing model predictions on noisy inputs:

>>> # Fetch some data
>>> x, y = next(iter(train_dataloader))
>>> x = x[:8]  # Only using the first 8 for easy plotting

>>> # Corrupt with a range of amounts
>>> amount = torch.linspace(0, 1, x.shape[0])  # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)

>>> # Get the model predictions
>>> with torch.no_grad():
...     preds = net(noised_x.to(device)).detach().cpu()

>>> # Plot
>>> fig, axs = plt.subplots(3, 1, figsize=(12, 7))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Corrupted data")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap="Greys")
>>> axs[2].set_title("Network Predictions")
>>> axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap="Greys")

You can see that for the lower amounts the predictions are pretty good! But as the level gets very high there is less for the model to work with, and by the time we get to amount=1 it outputs a blurry mess close to the mean of the dataset to try and hedge its bets on what the output might look like…

Sampling

If our predictions at high noise levels aren’t very good, how do we generate images?

Well, what if we start from random noise, look at the model predictions but then only move a small amount towards that prediction - say, 20% of the way there. Now we have a very noisy image in which there is perhaps a hint of structure, which we can feed into the model to get a new prediction. The hope is that this new prediction is slightly better than the first one (since our starting point is slightly less noisy) and so we can take another small step with this new, better prediction.

Repeat a few times and (if all goes well) we get an image out! Here is that process illustrated over just 5 steps, visualizing the model input (left) and the predicted denoised images (right) at each stage. Note that even though the model predicts the denoised image as early as step 1, we only move x part of the way towards that prediction. Over a few steps the structures appear and are refined, until we get our final outputs.

>>> # @markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
>>> n_steps = 5
>>> x = torch.rand(8, 1, 28, 28).to(device)  # Start from random
>>> step_history = [x.detach().cpu()]
>>> pred_output_history = []

>>> for i in range(n_steps):
...     with torch.no_grad():  # No need to track gradients during inference
...         pred = net(x)  # Predict the denoised x0
...     pred_output_history.append(pred.detach().cpu())  # Store model output for plotting
...     mix_factor = 1 / (n_steps - i)  # How much we move towards the prediction
...     x = x * (1 - mix_factor) + pred * mix_factor  # Move part of the way there
...     step_history.append(x.detach().cpu())  # Store step for plotting

>>> fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
>>> axs[0, 0].set_title("x (model input)")
>>> axs[0, 1].set_title("model prediction")
>>> for i in range(n_steps):
...     axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap="Greys")
...     axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap="Greys")

We can split the process up into more steps, and hope for better images that way:

>>> # @markdown Showing more results, using 40 sampling steps
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
...     noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps))  # Starting high going low (not actually used by this model)
...     with torch.no_grad():
...         pred = net(x)
...     mix_factor = 1 / (n_steps - i)
...     x = x * (1 - mix_factor) + pred * mix_factor
>>> fig, ax = plt.subplots(1, 1, figsize=(12, 12))
>>> ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")

Not great, but there are some recognizable digits there! You can experiment with training for longer (say, 10 or 20 epochs) and tweaking the model config, learning rate, optimizer and so on. Also, don’t forget that FashionMNIST is a one-line replacement if you want a slightly harder dataset to try.

Comparison To DDPM

In this section we’ll take a look at how our toy implementation differs from the approach used in the other notebook (Introduction to Diffusers), which is based on the DDPM paper.

We’ll see that

  • The diffusers UNet2DModel is a bit more advanced than our BasicUNet
  • The corruption process is handled differently
  • The training objective is different, involving predicting the noise rather than the denoised image
  • The model is conditioned on the amount of noise present via timestep conditioning, where t is passed as an additional argument to the forward method.
  • There are a number of different sampling strategies available, which should work better than our simplistic version above.

There have been a number of improvements suggested since the DDPM paper came out, but this example is hopefully instructive as to the different available design decisions. Once you’ve read through this, you may enjoy diving into the paper ‘Elucidating the Design Space of Diffusion-Based Generative Models’ which explores all of these components in some detail and makes new recommendations for how to get the best performance.

If all of this is too technical or intimidating, don’t worry! Feel free to skip the rest of this notebook or save it for a rainy day.

The UNet

The diffusers UNet2DModel model has a number of improvements over our basic UNet above:

  • GroupNorm applies group normalization to the inputs of each block
  • Dropout layers for smoother training
  • Multiple resnet layers per block (if layers_per_block isn’t set to 1)
  • Attention (usually used only at lower resolution blocks)
  • Conditioning on the timestep.
  • Downsampling and upsampling blocks with learnable parameters

Let’s create and inspect a UNet2DModel:

>>> model = UNet2DModel(
...     sample_size=28,  # the target image resolution
...     in_channels=1,  # the number of input channels, 3 for RGB images
...     out_channels=1,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "AttnDownBlock2D",
...     ),
...     up_block_types=(
...         "AttnUpBlock2D",
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",  # a regular ResNet upsampling block
...     ),
... )
>>> print(model)
UNet2DModel(
  (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=32, out_features=128, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=128, out_features=128, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (1): AttnDownBlock2D(
      (attentions): ModuleList(
        (0): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (2): AttnDownBlock2D(
      (attentions): ModuleList(
        (0): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
    )
  )
  (up_blocks): ModuleList(
    (0): AttnUpBlock2D(
      (attentions): ModuleList(
        (0): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (2): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList(
        (0): Upsample2D(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (1): AttnUpBlock2D(
      (attentions): ModuleList(
        (0): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (2): AttentionBlock(
          (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
          (conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList(
        (0): Upsample2D(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (2): UpBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
          (conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (mid_block): UNetMidBlock2D(
    (attentions): ModuleList(
      (0): AttentionBlock(
        (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (key): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (proj_attn): Linear(in_features=64, out_features=64, bias=True)
      )
    )
    (resnets): ModuleList(
      (0): ResnetBlock2D(
        (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
        (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (nonlinearity): SiLU()
      )
      (1): ResnetBlock2D(
        (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
        (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (nonlinearity): SiLU()
      )
    )
  )
  (conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True)
  (conv_act): SiLU()
  (conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

As you can see, there’s a little more going on here! It also has significantly more parameters than our BasicUNet:

sum([p.numel() for p in model.parameters()])  # 1.7M vs the ~309k parameters of the BasicUNet

We can replicate the training shown above using this model in place of our original one. We need to pass both x and a timestep to the model (here I always pass t=0 to show that it works without this timestep conditioning and to keep the sampling code easy, but you can also try feeding in (amount*1000) to get a timestep equivalent from the corruption amount). Lines changed from the earlier version are marked with #<<< if you want to inspect the code.
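
If you do want to experiment with that alternative, a hypothetical sketch of the changed lines (not used in the cell below) might look like:

# Hypothetical variant: derive a pseudo-timestep from the corruption amount,
# scaled into the scheduler's usual 0-999 range, instead of always passing 0
timesteps = (noise_amount * 999).long()
pred = net(noisy_x, timesteps).sample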

>>> # @markdown Trying UNet2DModel instead of BasicUNet:

>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

>>> # How many runs through the data should we do?
>>> n_epochs = 3

>>> # Create the network
>>> net = UNet2DModel(
...     sample_size=28,  # the target image resolution
...     in_channels=1,  # the number of input channels, 3 for RGB images
...     out_channels=1,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "AttnDownBlock2D",
...     ),
...     up_block_types=(
...         "AttnUpBlock2D",
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",  # a regular ResNet upsampling block
...     ),
... )  # <<<
>>> net.to(device)

>>> # Our loss function
>>> loss_fn = nn.MSELoss()

>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)

>>> # Keeping a record of the losses for later viewing
>>> losses = []

>>> # The training loop
>>> for epoch in range(n_epochs):

...     for x, y in train_dataloader:

...         # Get some data and prepare the corrupted version
...         x = x.to(device)  # Data on the GPU
...         noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
...         noisy_x = corrupt(x, noise_amount)  # Create our noisy x

...         # Get the model prediction
...         pred = net(noisy_x, 0).sample  # <<< Using timestep 0 always, adding .sample

...         # Calculate the loss
...         loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

...         # Backprop and update the params:
...         opt.zero_grad()
...         loss.backward()
...         opt.step()

...         # Store the loss for later
...         losses.append(loss.item())

...     # Print out the average of the loss values for this epoch:
...     avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
...     print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")

>>> # Plot losses and some samples
>>> fig, axs = plt.subplots(1, 2, figsize=(12, 5))

>>> # Losses
>>> axs[0].plot(losses)
>>> axs[0].set_ylim(0, 0.1)
>>> axs[0].set_title("Loss over time")

>>> # Samples
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
...     noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps))  # Starting high going low (not actually used by this model)
...     with torch.no_grad():
...         pred = net(x, 0).sample
...     mix_factor = 1 / (n_steps - i)
...     x = x * (1 - mix_factor) + pred * mix_factor

>>> axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Generated Samples")
Finished epoch 0. Average loss for this epoch: 0.018925
Finished epoch 1. Average loss for this epoch: 0.012785
Finished epoch 2. Average loss for this epoch: 0.011694

This looks quite a bit better than our first set of results! You can explore tweaking the unet configuration or training for longer to get even better performance.

The Corruption Process

The DDPM paper describes a corruption process that adds a small amount of noise for every ‘timestep’. Given $x_{t-1}$ for some timestep, we can get the next (slightly more noisy) version $x_t$ with:

$q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$

That is, we take $x_{t-1}$, scale it by $\sqrt{1 - \beta_t}$ and add noise with variance $\beta_t$. This $\beta$ is defined for every t according to some schedule, and determines how much noise is added per timestep. Now, we don’t necessarily want to apply this operation 500 times to get $x_{500}$, so we have another formula to get $x_t$ for any t given $x_0$:

$q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I})$ where $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$ and $\alpha_i = 1-\beta_i$
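
To make that closed-form expression concrete, here is a minimal sketch of producing $x_t$ in one shot from $x_0$ (the scheduler, stand-in data and timestep here are purely for illustration):

scheduler = DDPMScheduler(num_train_timesteps=1000)
alphas_cumprod = scheduler.alphas_cumprod  # \bar{alpha}_t for t = 0, ..., 999

x0 = torch.rand(8, 1, 28, 28) * 2 - 1  # Stand-in 'clean' data mapped to (-1, 1)
t = 500                                # An arbitrary timestep
noise = torch.randn_like(x0)

# x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise
xt = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * noise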

The maths notation always looks scary! Luckily the scheduler handles all that for us (uncomment the next cell to check out the code). We can plot $\sqrt{\bar{\alpha}_t}$ (labelled as sqrt_alpha_prod) and $\sqrt{(1 - \bar{\alpha}_t)}$ (labelled as sqrt_one_minus_alpha_prod) to view how the input (x) and the noise are scaled and mixed across different timesteps:

# ??noise_scheduler.add_noise
>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
>>> plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
>>> plt.legend(fontsize="x-large")

Initially, the noisy x is mostly x (sqrt_alpha_prod ~= 1) but over time the contribution of x drops and the noise component increases. Unlike our linear mix of x and noise according to amount, this one gets noisy relatively quickly. We can visualize this on some data:

>>> # @markdown visualize the DDPM noising process for different timesteps:

>>> # Noise a batch of images to view the effect
>>> fig, axs = plt.subplots(3, 1, figsize=(16, 10))
>>> xb, yb = next(iter(train_dataloader))
>>> xb = xb.to(device)[:8]
>>> xb = xb * 2.0 - 1.0  # Map to (-1, 1)
>>> print("X shape", xb.shape)

>>> # Show clean inputs
>>> axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[0].set_title("Clean X")

>>> # Add noise with scheduler
>>> timesteps = torch.linspace(0, 999, 8).long().to(device)
>>> noise = torch.randn_like(xb)  # << NB: randn not rand
>>> noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
>>> print("Noisy X shape", noisy_xb.shape)

>>> # Show noisy version (with and without clipping)
>>> axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap="Greys")
>>> axs[1].set_title("Noisy X (clipped to (-1, 1)")
>>> axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[2].set_title("Noisy X")
X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])

There’s another dynamic at play here: the DDPM version adds noise drawn from a Gaussian distribution (mean 0, s.d. 1, from torch.randn) rather than the uniform noise between 0 and 1 (from torch.rand) we used in our original corrupt function. In general, it also makes sense to normalize the training data. In the other notebook, you’ll see Normalize(0.5, 0.5) in the list of transforms, which maps the image data from (0, 1) to (-1, 1) and is ‘good enough’ for our purposes. We didn’t do that for this notebook, but the visualization cell above adds it in for more accurate scaling and visualization.
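
For reference, a sketch of that preprocessing (matching the Normalize(0.5, 0.5) transform mentioned above; the exact transform list in the other notebook may differ slightly):

preprocess = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),  # PIL image -> tensor with values in (0, 1)
        torchvision.transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 -> values in (-1, 1)
    ]
)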

Training Objective

In our toy example, we had the model try to predict the denoised image. In DDPM and many other diffusion model implementations, the model predicts the noise used in the corruption process (before scaling, so unit variance noise). In code, it looks something like:

noise = torch.randn_like(xb)  # << NB: randn not rand
noisy_x = noise_scheduler.add_noise(xb, noise, timesteps)
model_prediction = model(noisy_x, timesteps).sample
loss = F.mse_loss(model_prediction, noise)  # noise as the target
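
Given such a noise prediction, an estimate of the denoised image can be recovered by inverting the corruption formula above, along these lines (a sketch, reusing the scheduler’s alphas_cumprod):

# Invert x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise to estimate x_0
abar_t = noise_scheduler.alphas_cumprod.to(noisy_x.device)[timesteps].view(-1, 1, 1, 1)
x0_estimate = (noisy_x - (1 - abar_t).sqrt() * model_prediction) / abar_t.sqrt()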

You may think that predicting the noise (from which we can derive what the denoised image looks like) is equivalent to just predicting the denoised image directly. So why favour one over the other - is it just for mathematical convenience?

It turns out there’s another subtlety here. We compute the loss across different (randomly chosen) timesteps during training, and the two objectives lead to different ‘implicit weightings’ of those losses: predicting the noise puts more weight on lower noise levels. You can pick more complex objectives to change this ‘implicit loss weighting’. Or perhaps you choose a noise schedule that results in more training examples at higher noise levels. Perhaps you have the model predict a ‘velocity’ v, defined as a combination of the image and the noise that depends on the noise level (see ‘Progressive Distillation for Fast Sampling of Diffusion Models’). Perhaps you have the model predict the noise but then scale the loss by some factor dependent on the amount of noise, based on a bit of theory (see ‘Perception Prioritized Training of Diffusion Models’) or on experiments that try to identify which noise levels are most informative to the model (see ‘Elucidating the Design Space of Diffusion-Based Generative Models’). TL;DR: the choice of objective has an effect on model performance, and research is ongoing into what the ‘best’ option is.
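
For reference, the velocity target from that paper is usually written, in the notation above, as $v = \sqrt{\bar{\alpha}_t}\,\epsilon - \sqrt{1 - \bar{\alpha}_t}\,\mathbf{x}_0$, where $\epsilon$ is the unit-variance noise, so the target itself shifts between ‘mostly noise’ and ‘mostly image’ depending on the noise level.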

At the moment, predicting the noise (often referred to as epsilon or eps) is the favoured approach, but over time we will likely see other objectives supported in the library and used in different situations.

Timestep Conditioning

The UNet2DModel takes in both x and timestep. The latter is turned into an embedding and fed into the model in a number of places.

The theory behind this is that by giving the model information about what the noise level is, it can better perform its task. While it is possible to train a model without this timestep conditioning, it does seem to help performance in some cases and most implementations include it, at least in the current literature.
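
As a rough sketch of how this looks in a DDPM-style training step (the random timesteps and scheduler usage here mirror the other notebook rather than the code above, and assume model, xb and noise_scheduler are on the same device):

timesteps = torch.randint(0, 999, (xb.shape[0],)).long().to(device)  # A random timestep per example
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
pred = model(noisy_xb, timesteps).sample  # The timestep embedding conditions every block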

Sampling

Given a model that estimates the noise present in a noisy input (or predicts the denoised version), how do we produce new images?

We could feed in pure noise, and hope that the model predicts a good image as the denoised version in one step. However, as we saw in the experiments above, this doesn’t usually work well. So, instead, we take a number of smaller steps based on the model prediction, iteratively removing a little bit of the noise at a time.

Exactly how we take these steps depends on the sampling method used. We won’t go into the theory too deeply, but some key design questions are:

  • How large of a step should you take? In other words, what ‘noise schedule’ should you follow?
  • Do you use only the model’s current prediction to inform the update step (like DDPM, DDIM and many others)? Do you evaluate the model several times to estimate higher-order gradients for a larger, more accurate step (higher-order methods and some discrete ODE solvers)? Or do you keep a history of past predictions to try and better inform the current update step (linear multi-step and ancestral samplers)?
  • Do you add in additional noise (sometimes called churn) to add more stochasticity (randomness) to the sampling process, or do you keep it completely deterministic? Many samplers control this with a parameter (such as ‘eta’ for DDIM samplers) so that the user can choose.
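
To make this concrete, here is a minimal sketch of a sampling loop using the diffusers DDPMScheduler (this assumes a model trained to predict the noise, unlike our toy example above, and already moved to device):

scheduler = DDPMScheduler(num_train_timesteps=1000)
sample = torch.randn(8, 1, 28, 28).to(device)  # Start from pure Gaussian noise
for t in scheduler.timesteps:  # 999, 998, ..., 0
    with torch.no_grad():
        noise_pred = model(sample, t).sample  # The model predicts the noise at this timestep
    # The scheduler uses that prediction to produce a slightly less noisy sample
    sample = scheduler.step(noise_pred, t, sample).prev_sample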

Research on sampling methods for diffusion models is rapidly evolving, and more and more methods for finding good solutions in fewer steps are being proposed. The brave and curious might find it interesting to browse through the code of the different implementations available in the diffusers library, or check out the docs, which often link to the relevant papers.

Conclusions

Hopefully this has been a helpful way to look at diffusion models from a slightly different angle.

This notebook was written for this Hugging Face course by Jonathan Whitaker, and overlaps with a version included in his own course, ‘The Generative Landscape’. Check that out if you’d like to see this basic example extended with noise and class conditioning. Questions or bugs can be communicated through GitHub issues or via Discord. You are also welcome to reach out via Twitter @johnowhitaker.