|
""" |
|
--- |
|
title: U-Net for Stable Diffusion |
|
summary: > |
|
Annotated PyTorch implementation/tutorial of the U-Net in stable diffusion. |
|
--- |
|
|
|
# U-Net for [Stable Diffusion](../index.html) |
|
|
|
This implements the U-Net that |
|
gives $\epsilon_\text{cond}(x_t, c)$ |
|
|
|
We have kept the model definition and naming unchanged from |
|
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) |
|
so that we can load the checkpoints directly. |
|
""" |
|
|
|
import math |
|
from typing import List |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .unet_attention import SpatialTransformer |
|
|
|
|
|
class UNetModel(nn.Module):
    """
    ## U-Net model

    An encoder (`input_blocks`), a bottleneck (`middle_block`) and a decoder
    (`output_blocks`). Every encoder output is stored and concatenated, in
    reverse order, to the input of the matching decoder block.
    """

    def __init__(
            self,
            *,
            in_channels: int,
            out_channels: int,
            channels: int,
            n_res_blocks: int,
            attention_levels: List[int],
            channel_multipliers: List[int],
            n_heads: int,
            tf_layers: int = 1,
    ):
        """
        :param in_channels: is the number of channels in the input feature map
        :param out_channels: is the number of channels in the output feature map
        :param channels: is the base channel count for the model
        :param n_res_blocks: is the number of residual blocks at each level
        :param attention_levels: are the levels at which a `SpatialTransformer`
            is added after each residual block
        :param channel_multipliers: are the factors multiplied with `channels`
            to get the channel count of each level
        :param n_heads: is passed to every `SpatialTransformer` as its second argument
        :param tf_layers: is passed to every `SpatialTransformer` as its third argument
        """
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels

        # One resolution level per channel multiplier
        n_levels = len(channel_multipliers)
        # Number of channels at each level
        level_widths = [channels * mult for mult in channel_multipliers]

        # Size of the time embeddings fed to the residual blocks
        d_time_emb = channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels, d_time_emb),
            nn.SiLU(),
            nn.Linear(d_time_emb, d_time_emb),
        )

        # #### Input half of the U-Net
        #
        # It starts with a $3 \times 3$ convolution that maps the input
        # to `channels` feature maps.
        stem = nn.Conv2d(in_channels, channels, 3, padding=1)
        self.input_blocks = nn.ModuleList([TimestepEmbedSequential(stem)])

        # Channel count of every input block output; the output half pops
        # from this list to size its skip-connection inputs.
        skip_widths = [channels]
        # Running channel count of the feature map
        width = channels

        for level, level_width in enumerate(level_widths):
            with_attention = level in attention_levels
            for _ in range(n_res_blocks):
                # Residual block maps the running width to this level's width
                block = [ResBlock(width, d_time_emb, out_channels=level_width)]
                width = level_width
                if with_attention:
                    block.append(SpatialTransformer(width, n_heads, tf_layers))
                self.input_blocks.append(TimestepEmbedSequential(*block))
                skip_widths.append(width)

            # Down-sample at every level except the last
            if level < n_levels - 1:
                self.input_blocks.append(TimestepEmbedSequential(DownSample(width)))
                skip_widths.append(width)

        # #### Middle of the U-Net
        self.middle_block = TimestepEmbedSequential(
            ResBlock(width, d_time_emb),
            SpatialTransformer(width, n_heads, tf_layers),
            ResBlock(width, d_time_emb),
        )

        # #### Output half of the U-Net
        self.output_blocks = nn.ModuleList()

        # Levels are visited in reverse order
        for level in reversed(range(n_levels)):
            level_width = level_widths[level]
            with_attention = level in attention_levels
            # One more block than the input half, to consume the extra
            # skip connection (stem convolution / down-sampling output)
            for j in range(n_res_blocks + 1):
                # The input is the running feature map concatenated with
                # the matching skip connection
                block = [
                    ResBlock(
                        width + skip_widths.pop(),
                        d_time_emb,
                        out_channels=level_width,
                    )
                ]
                width = level_width
                if with_attention:
                    block.append(SpatialTransformer(width, n_heads, tf_layers))

                # Up-sample after the last block of every level except level 0
                is_last_block = j == n_res_blocks
                if level != 0 and is_last_block:
                    block.append(UpSample(width))

                self.output_blocks.append(TimestepEmbedSequential(*block))

        # Final normalization, activation and $3 \times 3$ convolution
        self.out = nn.Sequential(
            normalization(width),
            nn.SiLU(),
            nn.Conv2d(width, out_channels, 3, padding=1),
        )

    def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
        """
        ## Create sinusoidal time step embeddings

        :param time_steps: are the time steps of shape `[batch_size]`
        :param max_period: controls the minimum frequency of the embeddings.

        Returns a tensor of shape `[batch_size, 2 * (channels // 2)]`;
        the first half holds the cosines and the second half the sines.
        """
        # Half the embedding is $\cos$ and the other half is $\sin$
        half = self.channels // 2

        # $\frac{1}{\text{max\_period}^{i / \text{half}}}$ for $i = 0 \dots \text{half} - 1$
        exponents = (
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        )
        frequencies = torch.exp(exponents).to(device=time_steps.device)

        # Outer product of time steps and frequencies
        angles = time_steps[:, None].float() * frequencies[None]

        cos_part = torch.cos(angles)
        sin_part = torch.sin(angles)
        return torch.cat([cos_part, sin_part], dim=-1)

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor):
        """
        :param x: is the input feature map of shape `[batch_size, channels, height, width]`
        :param time_steps: are the time steps of shape `[batch_size]`
        """
        # Outputs of the input half, kept for the skip connections
        skips = []

        # Time step embeddings
        t_emb = self.time_embed(self.time_step_embedding(time_steps))

        # Input half of the U-Net
        for block in self.input_blocks:
            x = block(x, t_emb)
            skips.append(x)

        # Middle of the U-Net
        x = self.middle_block(x, t_emb)

        # Output half of the U-Net; skip connections are used last-in first-out
        for block in self.output_blocks:
            merged = torch.cat([x, skips.pop()], dim=1)
            x = block(merged, t_emb)

        # Final normalization and $3 \times 3$ convolution
        return self.out(x)
|
|
|
|
|
class TimestepEmbedSequential(nn.Sequential):
    """
    ### Sequential block for modules with different inputs

    This sequential module can be composed of different modules such as
    `ResBlock`, `nn.Conv2d` and `SpatialTransformer`, and calls each of them
    with the matching signature.
    """

    def forward(self, x, t_emb, cond=None):
        """
        :param x: is the feature map passed through the layers in order
        :param t_emb: is the time step embedding; only `ResBlock` layers receive it
        :param cond: is accepted but not used by this implementation
        """
        for layer in self:
            # Only residual blocks take the time embedding; every other layer
            # (including `SpatialTransformer`) is called with `x` alone.
            x = layer(x, t_emb) if isinstance(layer, ResBlock) else layer(x)
        return x
|
|
|
|
|
class UpSample(nn.Module):
    """
    ### Up-sampling layer

    Doubles the spatial size with nearest-neighbour interpolation and then
    applies a $3 \times 3$ convolution that keeps the number of channels.
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # $3 \times 3$ convolution mapping `channels` to `channels`
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # Up-sample by a factor of $2$
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        # Apply the convolution
        return self.conv(upsampled)
|
|
|
|
|
class DownSample(nn.Module):
    """
    ## Down-sampling layer

    A single $3 \times 3$ convolution with a stride of $2$ that keeps the
    number of channels.
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # Strided $3 \times 3$ convolution mapping `channels` to `channels`
        self.op = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        downsampled = self.op(x)
        return downsampled
|
|
|
|
|
class ResBlock(nn.Module):
    """
    ## ResNet Block

    Two normalization-activation-convolution stages with the projected
    time step embedding added in between, plus a skip connection.
    """

    def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
        """
        :param channels: the number of input channels
        :param d_t_emb: the size of timestep embeddings
        :param out_channels: is the number of out channels. Defaults to `channels`.
        """
        super().__init__()

        # `out_channels` not specified
        out_channels = channels if out_channels is None else out_channels

        # First normalization and convolution
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

        # Projection of the time step embedding to `out_channels`
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_t_emb, out_channels),
        )

        # Final convolution layer
        self.out_layers = nn.Sequential(
            normalization(out_channels),
            nn.SiLU(),
            nn.Dropout(0.),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        # A $1 \times 1$ convolution matches the channel count on the
        # residual path when it changes; otherwise the input passes through.
        self.skip_connection = (
            nn.Identity()
            if out_channels == channels
            else nn.Conv2d(channels, out_channels, 1)
        )

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        :param t_emb: is the time step embeddings of shape `[batch_size, d_t_emb]`
        """
        # Initial convolution
        hidden = self.in_layers(x)

        # Time step embeddings, cast to the feature map's dtype
        time_bias = self.emb_layers(t_emb).type(hidden.dtype)

        # Add the embedding to every spatial position, then the final convolution
        hidden = self.out_layers(hidden + time_bias[:, :, None, None])

        # Add the skip connection
        return self.skip_connection(x) + hidden
|
|
|
|
|
class GroupNorm32(nn.GroupNorm):
    """
    ### Group normalization with float32 casting

    The normalization is computed on a `float32` copy of the input and the
    result is cast back to the dtype of the input.
    """

    def forward(self, x):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
|
|
|
|
|
def normalization(channels):
    """
    ### Group normalization

    Helper that creates a `GroupNorm32` layer with the number of groups
    fixed to 32.

    :param channels: is the number of channels to normalize
    """
    return GroupNorm32(num_groups=32, num_channels=channels)
|
|
|
|
|
def _test_time_embeddings():
    """
    Test sinusoidal time step embeddings

    Builds a minimal `UNetModel` and plots a few embedding dimensions
    against the time step for time steps `0..999`.
    """
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5))
    # `UNetModel` takes no `d_cond` argument; passing one raises a `TypeError`.
    # With empty `channel_multipliers` there are no levels, so the model
    # consists of the stem convolution, the middle block and the output layers.
    m = UNetModel(
        in_channels=1,
        out_channels=1,
        channels=320,
        n_res_blocks=1,
        attention_levels=[],
        channel_multipliers=[],
        n_heads=1,
        tf_layers=1,
    )
    te = m.time_step_embedding(torch.arange(0, 1000))
    # Embedding dimensions to plot
    dims = [50, 100, 190, 260]
    plt.plot(np.arange(1000), te[:, dims].numpy())
    plt.legend([f"dim {p}" for p in dims])
    plt.title("Time embeddings")
    plt.show()
|
|
|
|
|
|
|
# Plot the time step embeddings when this file is run as a script
if __name__ == "__main__":
    _test_time_embeddings()
|
|