File size: 11,888 Bytes
bf8981a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 |
"""
---
title: Transformer for Stable Diffusion U-Net
summary: >
Annotated PyTorch implementation/tutorial of the transformer
for U-Net in stable diffusion.
---
# Transformer for Stable Diffusion [U-Net](unet.html)
This implements the transformer module used in [U-Net](unet.html) that
gives $\epsilon_\text{cond}(x_t, c)$.
We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
class SpatialTransformer(nn.Module):
    """
    ## Spatial Transformer

    Normalizes and projects a feature map, runs a stack of transformer layers
    over its flattened spatial positions, projects the result back and adds it
    to the input as a residual.
    """

    def __init__(self, channels: int, n_heads: int, n_layers: int):
        """
        :param channels: is the number of channels in the feature map
        :param n_heads: is the number of attention heads
        :param n_layers: is the number of transformer layers
        """
        super().__init__()

        # Group normalization applied first; it uses 32 groups, so
        # `channels` has to be divisible by 32
        self.norm = nn.GroupNorm(32, channels, eps=1e-6, affine=True)

        # $1 \times 1$ convolution applied before the transformer layers
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1)

        # Stack of transformer layers; each head gets `channels // n_heads` features
        d_head = channels // n_heads
        blocks = nn.ModuleList()
        for _ in range(n_layers):
            blocks.append(BasicTransformerBlock(channels, n_heads, d_head))
        self.transformer_blocks = blocks

        # $1 \times 1$ convolution applied after the transformer layers
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        """
        :param x: is the feature map of shape `[batch_size, channels, height, width]`
        """
        batch, n_channels, height, width = x.shape

        # Keep the untouched input for the residual connection
        residual = x

        # Normalize, then apply the initial $1 \times 1$ convolution
        hidden = self.proj_in(self.norm(x))

        # Move channels last and merge the spatial axes:
        # `[batch_size, channels, height, width]` -> `[batch_size, height * width, channels]`
        hidden = hidden.permute(0, 2, 3, 1)
        hidden = hidden.view(batch, height * width, n_channels)

        # Run every transformer layer in order
        for layer in self.transformer_blocks:
            hidden = layer(hidden)

        # Undo the flattening:
        # `[batch_size, height * width, channels]` -> `[batch_size, channels, height, width]`
        hidden = hidden.view(batch, height, width, n_channels)
        hidden = hidden.permute(0, 3, 1, 2)

        # Final $1 \times 1$ convolution plus the residual
        return self.proj_out(hidden) + residual
class BasicTransformerBlock(nn.Module):
    """
    ### Transformer Layer

    A pre-norm self-attention step followed by a pre-norm feed-forward step,
    each wrapped in a residual connection. No cross-attention layer is
    created in this block.
    """

    def __init__(self, d_model: int, n_heads: int, d_head: int):
        """
        :param d_model: is the input embedding size
        :param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
        """
        super().__init__()

        # Self-attention with its pre-norm layer; keys and values are computed
        # from the same embeddings as the queries, hence `d_cond = d_model`
        self.attn1 = CrossAttention(d_model=d_model, d_cond=d_model, n_heads=n_heads, d_head=d_head)
        self.norm1 = nn.LayerNorm(d_model)

        # Norm layer for the (absent) cross-attention step. It is registered
        # but `forward` never applies it.
        # NOTE(review): presumably kept so the parameter names still line up
        # with existing checkpoints - confirm before removing.
        self.norm2 = nn.LayerNorm(d_model)

        # Feed-forward network with its pre-norm layer
        self.ff = FeedForward(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
        """
        # Self-attention on the normalized input, plus residual
        attended = self.attn1(self.norm1(x))
        x = attended + x

        # Feed-forward network on the normalized activations, plus residual
        transformed = self.ff(self.norm3(x))
        return transformed + x
class CrossAttention(nn.Module):
    """
    ### Cross Attention Layer

    This falls-back to self-attention when conditional embeddings are not specified.
    """

    # Global switch: flash attention is used only when this is `True`,
    # the package is installed, and the call is a self-attention call.
    use_flash_attention: bool = False

    def __init__(
        self,
        d_model: int,
        d_cond: int,
        n_heads: int,
        d_head: int,
        is_inplace: bool = True,
    ):
        """
        :param d_model: is the input embedding size
        :param d_cond: is the size of the conditional embeddings
        :param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
        :param is_inplace: specifies whether to perform the attention softmax computation inplace to
            save memory
        """
        super().__init__()

        self.is_inplace = is_inplace
        self.n_heads = n_heads
        self.d_head = d_head

        # Attention scaling factor
        self.scale = d_head ** -0.5

        # Query, key and value mappings
        d_attn = d_head * n_heads
        self.to_q = nn.Linear(d_model, d_attn, bias=False)
        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
        self.to_v = nn.Linear(d_cond, d_attn, bias=False)

        # Final linear layer (wrapped in `nn.Sequential`, so its parameters
        # are registered under `to_out.0`)
        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))

        # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
        # Flash attention is only used if it's installed
        # and `CrossAttention.use_flash_attention` is set to `True`.
        try:
            # You can install flash attention by cloning their Github repo,
            # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
            # and then running `python setup.py install`
            from flash_attn.flash_attention import FlashAttention
            self.flash = FlashAttention()
            # Set the scale for scaled dot-product attention.
            self.flash.softmax_scale = self.scale
        # Set to `None` if it's not installed
        except ImportError:
            self.flash = None

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
        """
        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
        :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`;
            when it is `None` the layer attends over `x` itself (self-attention)
        """
        # If `cond` is `None` we perform self attention
        has_cond = cond is not None
        if not has_cond:
            cond = x

        # Get query, key and value vectors
        q = self.to_q(x)
        k = self.to_k(cond)
        v = self.to_v(cond)

        # Use flash attention if it's enabled and available, this is a self-attention
        # call, and the head size is less than or equal to `128`
        can_use_flash = (
            CrossAttention.use_flash_attention
            and self.flash is not None
            and not has_cond
            and self.d_head <= 128
        )
        if can_use_flash:
            return self.flash_attention(q, k, v)

        # Otherwise, fallback to normal attention
        return self.normal_attention(q, k, v)

    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        #### Flash Attention

        All three inputs must have the same sequence length, because they are
        stacked into a single tensor.

        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :raises ValueError: if the head size is larger than `128`
        """
        # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
        # the smallest of those sizes that fits. Validate this before doing any tensor work.
        for supported in (32, 64, 128):
            if self.d_head <= supported:
                pad = supported - self.d_head
                break
        else:
            raise ValueError(f'Head size {self.d_head} too large for Flash Attention')

        # Get batch size and number of elements along sequence axis (`width * height`)
        batch_size, seq_len, _ = q.shape

        # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
        # shape `[batch_size, seq_len, 3, n_heads * d_head]`
        qkv = torch.stack((q, k, v), dim=2)
        # Split the heads
        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)

        # Pad the heads with zeros
        if pad:
            qkv = torch.cat(
                (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
            )

        # Compute attention
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
        # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
        out, _ = self.flash(qkv)
        # Truncate the extra head size
        out = out[:, :, :, : self.d_head]
        # Reshape to `[batch_size, seq_len, n_heads * d_head]`
        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)

        # Map to `[batch_size, height * width, d_model]` with a linear layer
        return self.to_out(out)

    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        #### Normal Attention

        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq_kv, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq_kv, d_attn]`
        """
        # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
        q = q.view(*q.shape[:2], self.n_heads, -1)
        k = k.view(*k.shape[:2], self.n_heads, -1)
        v = v.view(*v.shape[:2], self.n_heads, -1)

        # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale

        # Compute softmax
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
        if self.is_inplace:
            # Process the batch in two halves and write the results back into
            # `attn`, so a second full-size attention tensor is never allocated
            half = attn.shape[0] // 2
            attn[half:] = attn[half:].softmax(dim=-1)
            attn[:half] = attn[:half].softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

        # Compute attention output
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
        out = torch.einsum('bhij,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, height * width, n_heads * d_head]`
        out = out.reshape(*out.shape[:2], -1)

        # Map to `[batch_size, height * width, d_model]` with a linear layer
        return self.to_out(out)
class FeedForward(nn.Module):
    """
    ### Feed-Forward Network

    A GeGLU layer that expands to the hidden size, a dropout layer with
    probability zero, and a linear layer that maps back to `d_model`.
    """

    def __init__(self, d_model: int, d_mult: int = 4):
        """
        :param d_model: is the input embedding size
        :param d_mult: is multiplicative factor for the hidden layer size
        """
        super().__init__()

        # Width of the hidden layer
        d_hidden = d_model * d_mult

        # The position of each layer inside the `nn.Sequential` determines
        # its parameter name (`net.0`, `net.2`), so the order is kept fixed
        layers = [
            GeGLU(d_model, d_hidden),
            nn.Dropout(p=0.0),
            nn.Linear(d_hidden, d_model),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the input embeddings, with `d_model` features on the last axis
        """
        out = self.net(x)
        return out
class GeGLU(nn.Module):
    """
    ### GeGLU Activation

    GeGLU(x) = (xW + b) * GELU(xV + c)
    """

    def __init__(self, d_in: int, d_out: int):
        """
        :param d_in: is the number of input features
        :param d_out: is the number of output features
        """
        super().__init__()
        # A single linear layer produces both $xW + b$ and $xV + c$,
        # so it is twice as wide as the output
        self.proj = nn.Linear(d_in, 2 * d_out)

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input, with `d_in` features on the last axis
        """
        # Project once, then split the last axis into the value half
        # $xW + b$ and the gate half $xV + c$
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        # $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
        return value * F.gelu(gate)
|