# diffusion_brain / models /aekl_no_attention.py
# Warvito
# commit message
# c9cd3be
# raw
# history blame
# No virus
# 6.93 kB
"""
AUTOENCODER WITH ARCHITECTURE FROM VERSION 2
"""
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@torch.jit.script
def swish(x):
    """Apply the swish activation, ``x * sigmoid(x)``, element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as ``x``.
    """
    gate = torch.sigmoid(x)
    return x * gate
def Normalize(in_channels):
    """Build the group-normalisation layer used by the blocks in this file.

    Args:
        in_channels: Number of channels of the tensor to normalise; forwarded
            to ``nn.GroupNorm`` as ``num_channels``. ``nn.GroupNorm`` requires
            it to be divisible by the fixed group count of 32.

    Returns:
        An ``nn.GroupNorm`` with 32 groups, ``eps=1e-6`` and learnable
        per-channel affine parameters.
    """
    group_norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
    return group_norm
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3x3 convolution."""

    def __init__(self, in_channels):
        """Create the convolution applied after interpolation.

        Args:
            in_channels: Channel count of the input; the output keeps the
                same number of channels.
        """
        super().__init__()
        # kernel 3 / stride 1 / padding 1 leaves the spatial size unchanged.
        self.conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Upsample ``x`` by a factor of 2 and convolve the result.

        Args:
            x: Input tensor; assumed to be laid out as (N, C, D, H, W).

        Returns:
            The convolved tensor, with every dimension after the first two
            doubled in size.
        """
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3x3 convolution."""

    def __init__(self, in_channels):
        """Create the strided convolution.

        Args:
            in_channels: Channel count of the input; the output keeps the
                same number of channels.
        """
        super().__init__()
        # No built-in padding: forward() pads one side only before the conv.
        self.conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Zero-pad the trailing side of the last three dims, then convolve.

        Args:
            x: Input tensor; assumed to be laid out as (N, C, D, H, W).

        Returns:
            The convolved tensor. A padded dimension of size ``n >= 2``
            comes out with size ``n // 2``.
        """
        # One voxel of zeros at the end of each of the last three dimensions.
        padded = F.pad(x, (0, 1, 0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3x3 Conv3d) stages plus a skip.

    When the input and output channel counts differ, the skip path is
    projected with a 1x1x1 convolution so the two branches can be added.
    """

    def __init__(self, in_channels, out_channels=None):
        """Build the block's layers.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor. ``None``
                (the default) means "same as ``in_channels``".
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels

        # Always build from the resolved self.out_channels: the raw
        # ``out_channels`` argument may be None, which nn.Conv3d and
        # nn.GroupNorm cannot accept.
        self.norm1 = Normalize(self.in_channels)
        self.conv1 = nn.Conv3d(
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.norm2 = Normalize(self.out_channels)
        self.conv2 = nn.Conv3d(
            self.out_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # Only needed when the residual branch changes the channel count.
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv3d(
                self.in_channels,
                self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

    def forward(self, x):
        """Return ``skip(x) + residual(x)``.

        Args:
            x: Input tensor with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels; the convolutions here
            preserve the spatial size.
        """
        h = x
        h = self.norm1(h)
        h = F.silu(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h
class Encoder(nn.Module):
    """Convolutional encoder without attention layers.

    Layer sequence: an initial 3x3x3 convolution, then for every entry of
    ``ch_mult`` a run of ``num_res_blocks`` residual blocks (with a
    ``Downsample`` between consecutive levels, none after the last), and
    finally a group normalisation followed by a 3x3x3 convolution to
    ``z_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        n_channels: int,
        z_channels: int,
        ch_mult: Tuple[int, ...],
        num_res_blocks: int,
        resolution: Tuple[int, ...],
        attn_resolutions: Tuple[int, ...],
        **ignorekwargs,
    ) -> None:
        """Build the encoder.

        Args:
            in_channels: Channels of the input tensor.
            n_channels: Base channel count; level ``i`` uses
                ``n_channels * ch_mult[i]`` channels.
            z_channels: Channels of the encoder output.
            ch_mult: Per-level channel multipliers; its length is the number
                of resolution levels. Must not be empty.
            num_res_blocks: Residual blocks per level.
            resolution: Stored on the instance; not used to build layers.
            attn_resolutions: Stored on the instance; this variant builds no
                attention layers, so it has no effect on the network.
            **ignorekwargs: Extra keys are accepted and ignored so one
                hyper-parameter dict can be shared with the decoder.

        Raises:
            ValueError: If ``ch_mult`` is empty.
        """
        super().__init__()
        self.in_channels = in_channels
        self.n_channels = n_channels
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        if self.num_resolutions == 0:
            raise ValueError("ch_mult must contain at least one channel multiplier")

        # Level i takes the channel width produced by level i - 1; the first
        # level takes the width of the initial convolution (multiplier 1).
        in_ch_mult = (1,) + tuple(ch_mult)

        # initial convolution
        blocks = [
            nn.Conv3d(
                in_channels,
                n_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        ]

        # residual blocks, with downsampling between levels (no attention)
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            block_in_ch = n_channels * in_ch_mult[level]
            block_out_ch = n_channels * ch_mult[level]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

            if level != last_level:
                blocks.append(Downsample(block_in_ch))

        # normalise and convert to latent size (no activation in between)
        blocks.append(Normalize(block_in_ch))
        blocks.append(
            nn.Conv3d(
                block_in_ch,
                z_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Pass ``x`` through every block in order and return the result."""
        for block in self.blocks:
            x = block(x)
        return x
class Decoder(nn.Module):
    """Convolutional decoder without attention layers.

    Layer sequence: an initial 3x3x3 convolution from ``z_channels`` to the
    widest level, then the levels of ``ch_mult`` walked in reverse, each a
    run of ``num_res_blocks`` residual blocks (with an ``Upsample`` after
    every level except level 0), and finally a group normalisation followed
    by a 3x3x3 convolution to ``out_channels``.
    """

    def __init__(
        self,
        n_channels: int,
        z_channels: int,
        out_channels: int,
        ch_mult: Tuple[int, ...],
        num_res_blocks: int,
        resolution: Tuple[int, ...],
        attn_resolutions: Tuple[int, ...],
        **ignorekwargs,
    ) -> None:
        """Build the decoder.

        Args:
            n_channels: Base channel count; level ``i`` uses
                ``n_channels * ch_mult[i]`` channels.
            z_channels: Channels of the latent input tensor.
            out_channels: Channels of the decoded output.
            ch_mult: Per-level channel multipliers; its length is the number
                of resolution levels. Must not be empty.
            num_res_blocks: Residual blocks per level.
            resolution: Stored on the instance; not used to build layers.
            attn_resolutions: Stored on the instance; this variant builds no
                attention layers, so it has no effect on the network.
            **ignorekwargs: Extra keys are accepted and ignored so one
                hyper-parameter dict can be shared with the encoder.

        Raises:
            ValueError: If ``ch_mult`` is empty.
        """
        super().__init__()
        self.n_channels = n_channels
        self.z_channels = z_channels
        self.out_channels = out_channels
        self.ch_mult = ch_mult
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        if self.num_resolutions == 0:
            raise ValueError("ch_mult must contain at least one channel multiplier")

        # Decoding starts at the widest (last) level.
        block_in_ch = n_channels * self.ch_mult[-1]

        # initial conv
        blocks = [
            nn.Conv3d(
                z_channels,
                block_in_ch,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        ]

        # residual blocks, with upsampling between levels (no attention)
        for level in reversed(range(self.num_resolutions)):
            block_out_ch = n_channels * self.ch_mult[level]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

            if level != 0:
                blocks.append(Upsample(block_in_ch))

        # normalise and convert to output channels (no activation in between)
        blocks.append(Normalize(block_in_ch))
        blocks.append(
            nn.Conv3d(
                block_in_ch,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Pass ``x`` through every block in order and return the result."""
        for block in self.blocks:
            x = block(x)
        return x
class AutoencoderKL(nn.Module):
    """Encoder/decoder pair with 1x1x1 convolutions around the latent space.

    The methods visible here only cover decoding; ``encoder``,
    ``quant_conv_mu`` and ``quant_conv_log_sigma`` are constructed but not
    called by any of them.
    """

    def __init__(self, embed_dim: int, hparams) -> None:
        """Build the sub-networks.

        Args:
            embed_dim: Channel count of the latent ``z`` accepted by
                :meth:`decode`.
            hparams: Mapping unpacked as keyword arguments into both
                ``Encoder`` and ``Decoder``; must also provide
                ``"z_channels"``.
        """
        super().__init__()
        self.embed_dim = embed_dim

        # Submodules are created in a fixed order so parameter registration
        # (and random initialisation) order stays the same.
        self.encoder = Encoder(**hparams)
        self.decoder = Decoder(**hparams)

        latent_channels = hparams["z_channels"]
        self.quant_conv_mu = nn.Conv3d(latent_channels, embed_dim, 1)
        self.quant_conv_log_sigma = nn.Conv3d(latent_channels, embed_dim, 1)
        self.post_quant_conv = nn.Conv3d(embed_dim, latent_channels, 1)

    def decode(self, z):
        """Map latent ``z`` back to ``z_channels`` channels and run the decoder."""
        return self.decoder(self.post_quant_conv(z))

    def reconstruct_ldm_outputs(self, z):
        """Decode latent ``z``; equivalent to calling :meth:`decode`."""
        return self.decode(z)