"""Adapted from https://github.com/lucidrains/denoising-diffusion-pytorch"""
import math | |
from collections import namedtuple | |
from functools import partial | |
from typing import List, Optional | |
import torch | |
import torch.nn.functional as F | |
from einops import rearrange | |
from torch import einsum, nn, Tensor | |
from trainers.utils import default, exists | |
# constants | |
# Named tuple with two fields, in order: `pred_noise` and `pred_x_start`.
ModelPrediction = namedtuple('ModelPrediction', 'pred_noise pred_x_start')
# helper functions
def l2norm(t: Tensor) -> Tensor:
    """Scale every vector along the last axis of ``t`` to unit L2 norm.

    :param t: Input tensor of any shape
    :return: Tensor of the same shape, normalized along ``dim=-1``
    """
    normalized = F.normalize(t, p=2.0, dim=-1)
    return normalized
# small helper modules | |
class Residual(nn.Module):
    """Skip-connection wrapper: computes ``fn(x, ...) + x`` for the wrapped module."""

    def __init__(self, fn: nn.Module):
        """
        :param fn: Module whose output is added back onto its first input
        """
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """
        :param x: Input that is both fed to ``fn`` and added to its output
        :param args: Extra positional arguments forwarded to ``fn``
        :param kwargs: Extra keyword arguments forwarded to ``fn``
        """
        branch = self.fn(x, *args, **kwargs)
        return branch + x
def Upsample(dim: int, dim_out: Optional[int] = None) -> nn.Sequential:
    """Nearest-neighbour upsampling by a factor of 2, followed by a 3x3 convolution.

    :param dim: Number of input channels
    :param dim_out: Number of output channels, resolved through ``default(dim_out, dim)``
    :return: ``nn.Sequential`` of the upsampling layer and the convolution
    """
    out_channels = default(dim_out, dim)
    resize = nn.Upsample(scale_factor=2, mode='nearest')
    conv = nn.Conv2d(dim, out_channels, 3, padding=1)
    return nn.Sequential(resize, conv)
def Downsample(dim: int, dim_out: Optional[int] = None) -> nn.Conv2d:
    """Strided convolution for downsampling (kernel 4, stride 2, padding 1).

    :param dim: Number of input channels
    :param dim_out: Number of output channels, resolved through ``default(dim_out, dim)``
    """
    out_channels = default(dim_out, dim)
    return nn.Conv2d(dim, out_channels, kernel_size=4, stride=2, padding=1)
class LayerNorm(nn.Module):
    """Normalization over the channel axis (dim 1) with a learnable gain and no bias."""

    def __init__(self, dim: int):
        """
        :param dim: Number of channels of the inputs to be normalized
        """
        super().__init__()
        # gain, initialised to one and shaped (1, dim, 1, 1) so it broadcasts over dim 1
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x: Tensor) -> Tensor:
        """
        :param x: Input whose dim 1 is the channel axis
        :return: Tensor of the same shape, standardized along dim 1 and scaled by the gain
        """
        # a looser epsilon is used for every dtype other than float32
        if x.dtype == torch.float32:
            eps = 1e-5
        else:
            eps = 1e-3
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        centered = x - mean
        inv_std = (var + eps).rsqrt()
        return centered * inv_std * self.g
class PreNorm(nn.Module):
    """Apply LayerNorm before any Module"""

    def __init__(self, dim: int, fn: nn.Module):
        """
        :param dim: Number of channels handed to the LayerNorm
        :param fn: Module applied to the normalized input
        """
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """
        :param x: Input tensor; it is normalized before being passed to ``fn``
        :param args: Extra positional arguments forwarded unchanged to ``fn``
        :param kwargs: Extra keyword arguments forwarded unchanged to ``fn``
        :return: Output of ``fn`` on the normalized input
        """
        # Extra arguments are passed through so that a wrapper which forwards
        # *args/**kwargs (such as Residual) can reach the wrapped module.
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)
class SinusoidalPosEmb(nn.Module):
    """Classical sinusoidal embedding"""

    def __init__(self, dim: int):
        """
        :param dim: Requested embedding dimension
        """
        super().__init__()
        self.dim = dim

    def forward(self, t: Tensor) -> Tensor:
        """
        :param t: Batch of time steps (b,)
        :return emb: Sinusoidal time embedding (b, 2 * (dim // 2)), i.e. (b, dim) for even dim
        """
        half_dim = self.dim // 2
        # With a single frequency (half_dim == 1) the spacing is irrelevant because
        # arange(1) == [0]; clamp the denominator so it cannot be zero.
        spacing = math.log(10000) / max(half_dim - 1, 1)
        # geometric progression of frequencies from 1 down to 1 / 10000
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -spacing)
        angles = t[:, None] * freqs[None, :]
        # first half of the embedding is sine, second half is cosine
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal time embedding whose frequencies are learned parameters.

    Following @crowsonkb 's lead with learned sinusoidal pos emb:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8
    """

    def __init__(self, dim: int):
        """
        :param dim: Width of the sin/cos part of the embedding; must be even
        """
        super().__init__()
        assert (dim % 2) == 0
        # one learned frequency per sin/cos pair, drawn from a standard normal
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, t: Tensor) -> Tensor:
        """
        :param t: Batch of time steps (b,)
        :return fouriered: Concatenation of t and time embedding (b, dim + 1)
        """
        t = rearrange(t, 'b -> b 1')
        weights = rearrange(self.weights, 'd -> 1 d')
        freqs = t * weights * 2 * math.pi
        sin_cos = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # the raw time step is kept as the first feature
        return torch.cat((t, sin_cos), dim=-1)
# building block modules | |
class Block(nn.Module):
    """3x3 convolution -> GroupNorm -> optional scale/shift modulation -> SiLU."""

    def __init__(self, dim: int, dim_out: int, groups: int = 8):
        """
        :param dim: Number of input channels
        :param dim_out: Number of output channels
        :param groups: Number of groups for the GroupNorm
        """
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x: Tensor, scale_shift: Optional[Tensor] = None) -> Tensor:
        """
        :param x: Input feature map
        :param scale_shift: Optional pair ``(scale, shift)``; when given, the normalized
                            features become ``x * (scale + 1) + shift``
        """
        x = self.norm(self.proj(x))

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)
class ResnetBlock(nn.Module):
    """Two stacked ``Block`` layers with a residual connection and optional time conditioning."""

    def __init__(
        self,
        dim: int,
        dim_out: int,
        *,
        time_emb_dim: Optional[int] = None,
        groups: int = 8
    ):
        """
        :param dim: Number of input channels
        :param dim_out: Number of output channels
        :param time_emb_dim: Size of the time embedding; no time MLP is built when absent
        :param groups: Number of GroupNorm groups in each ``Block``
        """
        super().__init__()

        # maps the time embedding to a scale and a shift, each with dim_out channels
        if exists(time_emb_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2),
            )
        else:
            self.time_mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)

        # a 1x1 convolution matches channel counts on the skip path when they differ
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x: Tensor, time_emb: Optional[Tensor] = None) -> Tensor:
        """
        :param x: Batch of input images (b, c, h, w)
        :param time_emb: Batch of time embeddings (b, c)
        """
        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            modulation = rearrange(self.time_mlp(time_emb), 'b c -> b c 1 1')
            # split the channels into the (scale, shift) pair consumed by block1
            scale_shift = modulation.chunk(2, dim=1)

        # only the first block receives the time conditioning
        hidden = self.block1(x, scale_shift=scale_shift)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    Queries and keys are soft-maxed separately and a per-head context matrix is
    built from keys and values, so no position-by-position similarity matrix
    is formed.
    """

    def __init__(self, dim: int, heads: int = 4, dim_head: int = 32):
        """
        :param dim: Number of input and output channels
        :param heads: Number of attention heads
        :param dim_head: Channels per head
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # a single 1x1 convolution yields queries, keys and values at once
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            LayerNorm(dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        :param x: Batch of input images (b, c, h, w)
        """
        _, _, height, width = x.shape
        heads = self.heads

        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h=heads)
            for part in qkv
        )

        # queries: softmax over the per-head channel axis, then scaled
        q = q.softmax(dim=-2) * self.scale
        # keys: softmax over the flattened spatial positions
        k = k.softmax(dim=-1)
        # values are divided by the number of positions
        v = v / (height * width)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=heads, x=height, y=width)
        return self.to_out(out)
class Attention(nn.Module):
    """Multi-head softmax attention over spatial positions with L2-normalized q and k."""

    def __init__(
        self,
        dim: int,
        heads: int = 4,
        dim_head: int = 32,
        scale: int = 16
    ):
        """
        :param dim: Number of input and output channels
        :param heads: Number of attention heads
        :param dim_head: Channels per head
        :param scale: Factor multiplied onto the similarities before the softmax
        """
        super().__init__()
        self.scale = scale
        self.heads = heads
        hidden_dim = dim_head * heads

        # a single 1x1 convolution yields queries, keys and values at once
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x: Tensor) -> Tensor:
        """
        :param x: Batch of input images (b, c, h, w)
        """
        _, _, height, width = x.shape

        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h=self.heads)
            for part in qkv
        )

        # NOTE(review): l2norm normalizes the last axis, which at this point is the
        # flattened spatial axis rather than the per-head channel axis - confirm intended.
        q = l2norm(q)
        k = l2norm(k)

        # similarity between every pair of positions i, j
        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)

        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=height, y=width)
        return self.to_out(out)
# model | |
class Unet(nn.Module):
    """U-Net of time-conditioned ResNet blocks with attention and skip connections.

    :param dim: Base channel width; stage widths are ``dim * m`` for each ``m`` in ``dim_mults``
    :param init_dim: Channels produced by the initial convolution, ``default(init_dim, dim)``
    :param out_dim: Output channels, ``default(out_dim, channels * (2 if learned_variance else 1))``
    :param dim_mults: Channel multipliers, one per resolution stage
    :param channels: Number of input image channels
    :param resnet_block_groups: Number of GroupNorm groups inside every ResNet block
    :param learned_variance: Doubles the default number of output channels
    :param learned_sinusoidal_cond: Use ``LearnedSinusoidalPosEmb`` instead of ``SinusoidalPosEmb``
    :param learned_sinusoidal_dim: Dimension handed to ``LearnedSinusoidalPosEmb``
    :param kwargs: Accepted and ignored
    """

    def __init__(
        self,
        dim: int = 64,
        init_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
        dim_mults: List[int] = [1, 2, 4, 8],
        channels: int = 1,
        resnet_block_groups: int = 8,
        learned_variance: bool = False,
        learned_sinusoidal_cond: bool = False,
        learned_sinusoidal_dim: int = 16,
        **kwargs
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *(dim * mult for mult in dim_mults)]
        # (in, out) channel pairs of consecutive stages
        in_out = list(zip(dims[:-1], dims[1:]))

        block_class = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        self.learned_sinusoidal_cond = learned_sinusoidal_cond

        if learned_sinusoidal_cond:
            sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
            # the learned embedding prepends the raw time step, hence the + 1
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_class(dim_in, dim_in, time_emb_dim=time_dim),
                block_class(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                # the last stage changes channels only and keeps the resolution
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(
                    dim_in, dim_out, 3, padding=1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_class(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_class(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                # each block consumes its input concatenated with one skip tensor
                block_class(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                block_class(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(
                    dim_out, dim_in, 3, padding=1)
            ]))

        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        # The final block sees the decoder output concatenated with the output of
        # init_conv; both carry init_dim channels, so it must be sized by init_dim
        # (sizing it by dim fails in forward whenever init_dim != dim).
        self.final_res_block = block_class(init_dim * 2, init_dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

    def forward(self, x: Tensor, timestep: Optional[Tensor]=None, cond: Optional[Tensor]=None) -> Tensor:
        """
        :param x: Batch of input images (b, channels, h, w); h and w should be divisible
                  by 2 for every downsampling stage so the skip connections line up
        :param timestep: Batch of time steps (b,); no time conditioning is applied when None
        :param cond: Currently unused
        :return: Tensor with ``out_dim`` channels
        """
        x = self.init_conv(x)
        # kept for the long skip connection into the final block
        r = x.clone()

        t = self.time_mlp(timestep) if timestep is not None else None

        # stack of skip tensors, two per resolution stage
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            # skips are consumed in reverse order of creation
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
if __name__ == '__main__':
    # Smoke test: run one forward pass on random input and print the output shape.
    model = Unet(channels=1)
    x = torch.randn(1, 1, 128, 128)
    timestep = torch.tensor([100])
    y = model(x, timestep=timestep)
    print(y.shape)