import torch.nn.functional as F
from torch import nn
from math import log2
from einops import rearrange
from .Blur import Blur
from .Noise import Noise
from .FCANet import FCANet
from .PreNorm import PreNorm
from .Conv2dSame import Conv2dSame
from .GlobalContext import GlobalContext
from .LinearAttention import LinearAttention
from .PixelShuffleUpsample import PixelShuffleUpsample
from .helper_funcs import exists, is_power_of_two, default
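# Lightweight-GAN-style generator: stacks pixel-shuffle upsampling blocks with
# optional linear attention and skip-layer excitation (SLE) between stages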
class Generator(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=None,
        freq_chan_attn=False,
        syncbatchnorm=False,
        antialias=False,
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), "image size must be a power of 2"
        # Set the normalization and blur layers; rebinding the imported Blur
        # name here would raise UnboundLocalError when antialias=True, so the
        # choice is bound to a fresh local name instead
        norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
        blur_class = Blur if antialias else nn.Identity
        # Number of output image channels
        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3
        self.latent_dim = latent_dim
        fmap_max = default(fmap_max, latent_dim)
        attn_res_layers = default(attn_res_layers, [])
        # Project the 1x1 latent feature map up to 4x4; GLU halves the doubled
        # channel count back to latent_dim
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1),
        )
        num_layers = int(resolution) - 2
        # Feature map widths per stage: 2 ** (fmap_inverse_coef - res), capped
        # at fmap_max, and forced to 3 at resolutions of 2**8 and above
        features = list(
            map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2))
        )
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]
        in_out_features = list(zip(features[:-1], features[1:]))
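        # For example, image_size=256 (resolution 8) with the defaults gives
        # features = [256, 512, 512, 256, 128, 64, 32], so the stages run
        # (256 -> 512), (512 -> 512), ..., (64 -> 32) channels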
        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))
        # Skip-layer excitation (SLE) map: features from a shallow stage gate
        # the features of a much deeper stage; pairs that exceed the target
        # resolution are dropped
        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = list(
            filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map)
        )
        self.sle_map = dict(self.sle_map)
        self.num_layers_spatial_res = 1
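        # At resolution 8 (a 256px image), for instance, only {3: 7, 4: 8}
        # survive: each shallow stage excites the stage four levels deeper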
        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            image_width = 2**res
            # Optional linear attention at the requested resolutions
            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))
            # Optional skip-layer excitation feeding a deeper stage
            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
                if freq_chan_attn:
                    sle = FCANet(
                        chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)
            layer = nn.ModuleList(
                [
                    nn.Sequential(
                        PixelShuffleUpsample(chan_in),
                        blur_class(),
                        Conv2dSame(chan_in, chan_out * 2, 4),
                        Noise(),
                        norm_class(chan_out * 2),
                        nn.GLU(dim=1),
                    ),
                    sle,
                    attn,
                ]
            )
            self.layers.append(layer)
        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)
    def forward(self, x):
        # Treat the latent vector as a 1x1 feature map
        x = rearrange(x, "b c -> b c () ()")
        x = self.initial_conv(x)
        x = F.normalize(x, dim=1)
        residuals = dict()
        for res, (up, sle, attn) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x
            x = up(x)
            # Stash the skip-layer excitation for the deeper stage it targets
            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual
            # Apply the excitation a shallower stage stored for this resolution
            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]
        return self.out_conv(x)
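

if __name__ == "__main__":
    # Minimal smoke test, assuming this file lives in a package alongside the
    # modules imported above (run as `python -m <package>.Generator`); the
    # shapes below follow from the default hyperparameters
    import torch

    gen = Generator(image_size=256, latent_dim=256)
    latents = torch.randn(2, 256)
    with torch.no_grad():
        images = gen(latents)
    print(images.shape)  # expected: torch.Size([2, 3, 256, 256])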