# TEDM-demo/models/unet_model.py
"""Adapted from https://github.com/lucidrains/denoising-diffusion-pytorch"""
import math
from collections import namedtuple
from functools import partial
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn, Tensor
from trainers.utils import default, exists
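# `exists` and `default` are small helpers from trainers/utils.py: as used
# here, exists(x) is True iff x is not None, and default(x, d) returns x if
# it exists, else d.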
# constants
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
# helpers functions
def l2norm(t: Tensor) -> Tensor:
"""L2 normalize along last dimension"""
return F.normalize(t, dim=-1)
# small helper modules
class Residual(nn.Module):
"""Residual of any Module -> x' = f(x) + x"""
def __init__(self, fn: nn.Module):
super().__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
def Upsample(dim: int, dim_out: Optional[int] = None) -> nn.Sequential:
"""UpsampleConv with factor 2"""
return nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(dim, default(dim_out, dim), 3, padding=1)
)
def Downsample(dim: int, dim_out: Optional[int] = None) -> nn.Conv2d:
"""Strided Conv2d for downsampling"""
return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
class LayerNorm(nn.Module):
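    """LayerNorm over the channel dim of (b, c, h, w) tensors, with a learnable gain and no bias"""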
def __init__(self, dim: int):
super().__init__()
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x: Tensor) -> Tensor:
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) * (var + eps).rsqrt() * self.g
class PreNorm(nn.Module):
"""Apply LayerNorm before any Module"""
def __init__(self, dim: int, fn: nn.Module):
super().__init__()
self.fn = fn
self.norm = LayerNorm(dim)
def forward(self, x: Tensor) -> Tensor:
x = self.norm(x)
return self.fn(x)
class SinusoidalPosEmb(nn.Module):
"""Classical sinosoidal embedding"""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, t: Tensor) -> Tensor:
"""
:param t: Batch of time steps (b,)
:return emb: Sinusoidal time embedding (b, dim)
"""
device = t.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class LearnedSinusoidalPosEmb(nn.Module):
""" following @crowsonkb 's lead with learned sinusoidal pos emb """
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, t: Tensor) -> Tensor:
"""
:param t: Batch of time steps (b,)
:return fouriered: Concatenation of t and time embedding (b, dim + 1)
"""
t = rearrange(t, 'b -> b 1')
freqs = t * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((t, fouriered), dim=-1)
return fouriered
# building block modules
class Block(nn.Module):
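    """Conv3x3 -> GroupNorm -> optional FiLM-style scale/shift -> SiLU"""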
def __init__(self, dim: int, dim_out: int, groups: int = 8):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
    def forward(self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
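    """Two Blocks with a residual connection; an optional time embedding is
    projected to a per-channel scale and shift for the first Block"""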
def __init__(
self,
dim: int,
dim_out: int,
*,
time_emb_dim: Optional[int] = None,
groups: int = 8
):
super().__init__()
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = nn.Identity()
def forward(self, x: Tensor, time_emb: Optional[Tensor] = None) -> Tensor:
"""
:param x: Batch of input images (b, c, h, w)
:param time_emb: Batch of time embeddings (b, c)
"""
scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim=1)
h = self.block1(x, scale_shift=scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
class LinearAttention(nn.Module):
"""Attention with linear to_qtv"""
def __init__(self, dim: int, heads: int = 4, dim_head: int = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Sequential(
nn.Conv2d(hidden_dim, dim, 1),
LayerNorm(dim)
)
def forward(self, x: Tensor) -> Tensor:
"""
:param x: Batch of input images (b, c, h, w)
"""
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
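        # Linear attention: normalize q over the channel dim and k over the
        # spatial dim, then aggregate the k.v context before applying q, so
        # the cost is linear rather than quadratic in the number of pixels.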
q = q.softmax(dim=-2)
k = k.softmax(dim=-1)
q = q * self.scale
v = v / (h * w)
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)
return self.to_out(out)
class Attention(nn.Module):
"""Attention with convolutional to_qtv"""
def __init__(
self,
dim: int,
heads: int = 4,
dim_head: int = 32,
scale: int = 16
):
super().__init__()
self.scale = scale
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x: Tensor) -> Tensor:
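        """
        :param x: Batch of input images (b, c, h, w)
        """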
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
q, k = map(l2norm, (q, k))
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim=-1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
return self.to_out(out)
# model
class Unet(nn.Module):
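    """Denoising U-Net: ResnetBlock stages with linear attention at each
    resolution, full attention at the bottleneck, and a long skip connection
    from the initial conv to the final ResnetBlock. Conditioning on the
    diffusion time step uses an MLP over a (learned) sinusoidal embedding.
    Extra keyword arguments are accepted and ignored.
    """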
def __init__(
self,
dim: int = 64,
init_dim: Optional[int] = None,
out_dim: Optional[int] = None,
        dim_mults: Tuple[int, ...] = (1, 2, 4, 8),
channels: int = 1,
resnet_block_groups: int = 8,
learned_variance: bool = False,
learned_sinusoidal_cond: bool = False,
learned_sinusoidal_dim: int = 16,
**kwargs
):
super().__init__()
# determine dimensions
self.channels = channels
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_class = partial(ResnetBlock, groups=resnet_block_groups)
# time embeddings
time_dim = dim * 4
self.learned_sinusoidal_cond = learned_sinusoidal_cond
if learned_sinusoidal_cond:
sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
fourier_dim = learned_sinusoidal_dim + 1
else:
sinu_pos_emb = SinusoidalPosEmb(dim)
fourier_dim = dim
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
block_class(dim_in, dim_in, time_emb_dim=time_dim),
block_class(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(
dim_in, dim_out, 3, padding=1)
]))
mid_dim = dims[-1]
self.mid_block1 = block_class(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_class(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
self.ups.append(nn.ModuleList([
block_class(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
block_class(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(
dim_out, dim_in, 3, padding=1)
]))
        default_out_dim = channels * (2 if learned_variance else 1)
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = block_class(dim * 2, dim, time_emb_dim=time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
    def forward(self, x: Tensor, timestep: Optional[Tensor] = None, cond: Optional[Tensor] = None) -> Tensor:
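        """
        :param x: Batch of input images (b, c, h, w)
        :param timestep: Batch of diffusion time steps (b,)
        :param cond: Not used by this model
        :return: Prediction with self.out_dim channels (b, out_dim, h, w)
        """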
x = self.init_conv(x)
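        # Keep a copy of the stem output for the final long skip connection.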
r = x.clone()
t = self.time_mlp(timestep) if timestep is not None else None
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim=1)
x = block2(x, t)
x = attn(x)
x = upsample(x)
x = torch.cat((x, r), dim=1)
x = self.final_res_block(x, t)
return self.final_conv(x)
if __name__ == '__main__':
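    # Smoke test: the output should keep the input spatial size, with
    # out_dim (= channels here) output channels.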
model = Unet(channels=1)
x = torch.randn(1, 1, 128, 128)
y = model(x, timestep=torch.tensor([100]))
    print(y.shape)  # torch.Size([1, 1, 128, 128])
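    # Variant sketch: with learned_variance=True the head outputs 2 * channels
    # maps (as in iDDPM-style models, typically the predicted noise plus a
    # variance-interpolation term).
    model_lv = Unet(channels=1, learned_variance=True)
    y_lv = model_lv(x, timestep=torch.tensor([100]))
    print(y_lv.shape)  # torch.Size([1, 2, 128, 128])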