haoyuliu00's picture
Initial commit with cleaned history
bf8981a
"""
---
title: U-Net for Stable Diffusion
summary: >
Annotated PyTorch implementation/tutorial of the U-Net in stable diffusion.
---
# U-Net for [Stable Diffusion](../index.html)
This implements the U-Net that
gives $\epsilon_\text{cond}(x_t, c)$.
We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
import math
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_attention import SpatialTransformer
class UNetModel(nn.Module):
    """
    ## U-Net model

    The network has three parts: `input_blocks` (the down-sampling half),
    `middle_block`, and `output_blocks` (the up-sampling half). The output of
    every input block is stored and concatenated, along the channel dimension,
    onto the input of the corresponding output block (skip connections).

    Sub-module attribute names and their construction order are deliberately
    left untouched so that parameter names in the state dict stay the same.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int,
        channels: int,
        n_res_blocks: int,
        attention_levels: List[int],
        channel_multipliers: List[int],
        n_heads: int,
        tf_layers: int = 1,
    ):
        """
        :param in_channels: is the number of channels in the input feature map
        :param out_channels: is the number of channels in the output feature map
        :param channels: is the base channel count for the model
        :param n_res_blocks: number of residual blocks at each level
        :param attention_levels: are the levels (indices into `channel_multipliers`)
            at which a `SpatialTransformer` is added after each residual block
        :param channel_multipliers: are the multiplicative factors for number of channels for each level
        :param n_heads: the number of attention heads in the transformers
        :param tf_layers: forwarded as the third argument of `SpatialTransformer`
            (presumably the number of transformer layers - confirm against `unet_attention`)
        """
        super().__init__()
        # Keep the base channel count; `channels` the local is reassigned below
        # while the levels are being built.
        self.channels = channels
        self.out_channels = out_channels

        # Number of levels
        levels = len(channel_multipliers)
        # Size time embeddings
        d_time_emb = channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels, d_time_emb),
            nn.SiLU(),
            nn.Linear(d_time_emb, d_time_emb),
        )

        # Input half of the U-Net
        self.input_blocks = nn.ModuleList()
        # Initial $3 \times 3$ convolution that maps the input to `channels`.
        # The blocks are wrapped in `TimestepEmbedSequential` module because
        # different modules have different forward function signatures;
        # for example, convolution only accepts the feature map and
        # residual blocks accept the feature map and time embedding.
        # `TimestepEmbedSequential` calls them accordingly.
        self.input_blocks.append(
            TimestepEmbedSequential(nn.Conv2d(in_channels, channels, 3, padding=1))
        )
        # Number of channels at each block in the input half of U-Net
        input_block_channels = [channels]
        # Number of channels at each level
        channels_list = [channels * m for m in channel_multipliers]

        # Prepare levels
        for i in range(levels):
            # Add the residual blocks and attentions
            for _ in range(n_res_blocks):
                # Residual block maps from previous number of channels to the number of
                # channels in the current level
                layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]
                # Add transformer
                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers))
                # Add them to the input half of the U-Net and keep track of the
                # number of channels of its output
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_channels.append(channels)
            # Down sample at all levels except last
            if i != levels - 1:
                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
                input_block_channels.append(channels)

        # The middle of the U-Net
        self.middle_block = TimestepEmbedSequential(
            ResBlock(channels, d_time_emb),
            SpatialTransformer(channels, n_heads, tf_layers),
            ResBlock(channels, d_time_emb),
        )

        # Second half of the U-Net
        self.output_blocks = nn.ModuleList([])
        # Prepare levels in reverse order
        for i in reversed(range(levels)):
            # Add the residual blocks and attentions.
            # There is one more block per level than in the input half so that
            # every entry of `input_block_channels` is consumed exactly once.
            for j in range(n_res_blocks + 1):
                # Residual block maps from previous number of channels plus the
                # skip connections from the input half of U-Net to the number of
                # channels in the current level.
                layers = [
                    ResBlock(
                        channels + input_block_channels.pop(),
                        d_time_emb,
                        out_channels=channels_list[i],
                    )
                ]
                channels = channels_list[i]
                # Add transformer
                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers))
                # Up-sample at every level after last residual block
                # except the last one.
                # Note that we are iterating in reverse; i.e. `i == 0` is the last.
                if i != 0 and j == n_res_blocks:
                    layers.append(UpSample(channels))
                # Add to the output half of the U-Net
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        # Final normalization and $3 \times 3$ convolution
        self.out = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

    def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
        """
        ## Create sinusoidal time step embeddings

        :param time_steps: are the time steps of shape `[batch_size]`
        :param max_period: controls the minimum frequency of the embeddings.
        :return: embeddings of shape `[batch_size, self.channels]`; the first
            `self.channels // 2` columns are cosines and the next
            `self.channels // 2` are sines. When `self.channels` is odd a
            single zero column is appended so the width is always `self.channels`.
        """
        # $\frac{c}{2}$; half the channels are sin and the other half is cos,
        half = self.channels // 2
        # $\frac{1}{10000^{\frac{2i}{c}}}$
        frequencies = torch.exp(
            -math.log(max_period) *
            torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=time_steps.device)
        # $\frac{t}{10000^{\frac{2i}{c}}}$
        args = time_steps[:, None].float() * frequencies[None]
        # $\cos\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$ and $\sin\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # `2 * half` is one short of `self.channels` for an odd channel count;
        # pad with a zero column so the result always fits `self.time_embed`.
        if self.channels % 2:
            padding = embedding.new_zeros(embedding.shape[0], 1)
            embedding = torch.cat([embedding, padding], dim=-1)
        return embedding

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor):
        """
        :param x: is the input feature map of shape `[batch_size, in_channels, height, width]`
        :param time_steps: are the time steps of shape `[batch_size]`
        :return: the output of the final convolution, with `out_channels` channels

        NOTE(review): the skip connections are concatenated along the channel
        dimension, so the spatial sizes of both halves have to match. With the
        stride-2 down-sampling and factor-2 up-sampling used here that
        presumably requires `height` and `width` to be divisible by
        `2 ** (levels - 1)` - confirm with the callers.
        """
        # To store the input half outputs for skip connections
        x_input_block = []

        # Get time step embeddings
        t_emb = self.time_step_embedding(time_steps)
        t_emb = self.time_embed(t_emb)

        # Input half of the U-Net
        for module in self.input_blocks:
            x = module(x, t_emb)
            x_input_block.append(x)
        # Middle of the U-Net
        x = self.middle_block(x, t_emb)
        # Output half of the U-Net; skip connections are consumed last-in, first-out
        for module in self.output_blocks:
            x = torch.cat([x, x_input_block.pop()], dim=1)
            x = module(x, t_emb)

        # Final normalization and $3 \times 3$ convolution
        return self.out(x)
class TimestepEmbedSequential(nn.Sequential):
    """
    ### Sequential block for modules with different inputs

    This sequential module can be composed of different modules such as `ResBlock`,
    `nn.Conv2d` and `SpatialTransformer`, and calls each of them with the matching signature.
    """

    def forward(self, x, t_emb, cond=None):
        """
        :param x: is the feature map; it is passed through the layers in order
        :param t_emb: is the time step embedding; it is given to `ResBlock` layers only
        :param cond: is accepted but not used; no layer receives it
        :return: the output of the last layer
        """
        for layer in self:
            if isinstance(layer, ResBlock):
                # Residual blocks are the only layers that take the time embedding
                x = layer(x, t_emb)
            else:
                # Every other layer (convolutions, up/down-sampling, transformers)
                # is called with the feature map alone
                x = layer(x)
        return x
class UpSample(nn.Module):
    """
    ### Up-sampling layer

    Doubles the spatial size with nearest-neighbour interpolation and then
    applies a 3x3 convolution that keeps the number of channels.
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # 3x3 convolution; padding of 1 keeps the spatial size
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # Nearest-neighbour up-sampling by a factor of 2
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        # Convolve the enlarged feature map
        return self.conv(upsampled)
class DownSample(nn.Module):
    """
    ## Down-sampling layer

    A single strided 3x3 convolution that halves the spatial size and keeps
    the number of channels.
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # Stride of 2 is what performs the down-sampling
        self.op = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        downsampled = self.op(x)
        return downsampled
class ResBlock(nn.Module):
    """
    ## ResNet Block

    Two normalization/activation/convolution stages with the projected time
    step embedding added in between, plus a residual connection from the input.
    """

    def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
        """
        :param channels: the number of input channels
        :param d_t_emb: the size of timestep embeddings
        :param out_channels: is the number of out channels. defaults to `channels`.
        """
        super().__init__()
        # Fall back to the input width when no output width was requested
        n_out = channels if out_channels is None else out_channels

        # Normalization, activation and the first 3x3 convolution
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, n_out, kernel_size=3, padding=1),
        )

        # Projects the time step embedding to one value per output channel
        self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(d_t_emb, n_out))

        # Normalization, activation, (zero-probability) dropout and the second 3x3 convolution
        self.out_layers = nn.Sequential(
            normalization(n_out),
            nn.SiLU(),
            nn.Dropout(0.),
            nn.Conv2d(n_out, n_out, kernel_size=3, padding=1),
        )

        # The residual path needs a 1x1 convolution only when the channel count changes
        if n_out != channels:
            self.skip_connection = nn.Conv2d(channels, n_out, kernel_size=1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        :param t_emb: is the time step embeddings of shape `[batch_size, d_t_emb]`
        """
        # First stage
        hidden = self.in_layers(x)
        # Project the time embedding and match the feature map's dtype
        emb = self.emb_layers(t_emb).type(hidden.dtype)
        # Broadcast the per-channel embedding over the two spatial dimensions
        hidden = hidden + emb[:, :, None, None]
        # Second stage
        hidden = self.out_layers(hidden)
        # Residual connection
        residual = self.skip_connection(x)
        return residual + hidden
class GroupNorm32(nn.GroupNorm):
    """
    ### Group normalization with float32 casting

    The input is converted to float32 for the normalization and the result is
    converted back to the dtype the input came in with.
    """

    def forward(self, x):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
def normalization(channels):
    """
    ### Group normalization

    Helper that builds a `GroupNorm32` layer for `channels` channels with the
    number of groups fixed at 32.
    """
    return GroupNorm32(num_groups=32, num_channels=channels)
def _test_time_embeddings():
    """
    Test sinusoidal time step embeddings

    Builds a `UNetModel` with no levels, computes the embeddings of time steps
    0..999 and plots four of the embedding dimensions against the time step.
    """
    # Imported here so that matplotlib is only needed when this check is run
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5))
    # `UNetModel` takes keyword-only arguments and has no `d_cond` parameter,
    # so it must not be passed here.
    m = UNetModel(
        in_channels=1,
        out_channels=1,
        channels=320,
        n_res_blocks=1,
        attention_levels=[],
        channel_multipliers=[],
        n_heads=1,
        tf_layers=1,
    )
    te = m.time_step_embedding(torch.arange(0, 1000))
    # Embedding dimensions to plot
    dims = [50, 100, 190, 260]
    plt.plot(np.arange(1000), te[:, dims].numpy())
    plt.legend([f"dim {p}" for p in dims])
    plt.title("Time embeddings")
    plt.show()
#
# Plot the time step embeddings when this file is executed as a script
if __name__ == "__main__":
    _test_time_embeddings()