|
import torch.nn.functional as F
|
|
from torch import nn
|
|
from math import log2
|
|
from einops import rearrange
|
|
|
|
from .Blur import Blur
|
|
from .Noise import Noise
|
|
from .FCANet import FCANet
|
|
from .PreNorm import PreNorm
|
|
from .Conv2dSame import Conv2dSame
|
|
from .GlobalContext import GlobalContext
|
|
from .LinearAttention import LinearAttention
|
|
from .PixelShuffleUpsample import PixelShuffleUpsample
|
|
from .helper_funcs import exists, is_power_of_two, default
|
|
|
|
|
|
class Generator(nn.Module):
    """Generator network mapping a batch of latent vectors to images.

    A transposed convolution turns each latent vector into a 4x4 feature map.
    One stage per entry of ``res_layers`` follows. Each stage is
    ``PixelShuffleUpsample -> (Blur | Identity) -> Conv2dSame -> Noise -> norm -> GLU``.
    ``PixelShuffleUpsample`` is assumed to double the spatial size (TODO
    confirm against its implementation).

    Some stages add extra modules:

    * A stage may first apply a residual ``PreNorm(LinearAttention)`` to its
      input. This happens when the stage's input width is listed in
      ``attn_res_layers``.
    * A stage listed as a key of ``sle_map`` builds a gating module
      (``GlobalContext``, or ``FCANet`` when ``freq_chan_attn``). Its output
      is stored and later multiplied into the output of stage
      ``sle_map[res] - 1``.

    A final 3x3 conv (``out_conv``) maps to ``init_channel`` output channels.
    """

    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=(),
        freq_chan_attn=False,
        syncbatchnorm=False,
        antialias=False,
    ):
        """Build the generator.

        Args:
            image_size: Target image side length; must be a power of two.
            latent_dim: Channel count of the latent input and of the initial
                4x4 feature map.
            fmap_max: Upper bound on channels per stage; ``None`` falls back
                to ``latent_dim`` (via ``default``).
            fmap_inverse_coef: Stage ``res`` nominally gets
                ``2 ** (fmap_inverse_coef - res)`` channels before clamping
                to ``fmap_max``.
            transparent: Emit 4 output channels (takes precedence over
                ``greyscale``).
            greyscale: Emit 1 output channel (otherwise 3).
            attn_res_layers: Widths (``2 ** res``) whose stage input gets a
                residual linear-attention block. Any container supporting
                ``in`` works.
            freq_chan_attn: Use ``FCANet`` instead of ``GlobalContext`` for
                the cross-stage gating modules.
            syncbatchnorm: Use ``nn.SyncBatchNorm`` instead of
                ``nn.BatchNorm2d``.
            antialias: Insert ``Blur()`` after each upsample; otherwise an
                ``nn.Identity`` placeholder is used.
        """
        super().__init__()

        # Validate before log2 so a bad size fails on the intended check.
        assert is_power_of_two(image_size), "image size must be a power of 2"
        resolution = log2(image_size)

        norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
        # Bind to a NEW local name. Assigning to ``Blur`` here would make it
        # a function local for the whole method, so reading the imported
        # ``Blur`` would raise UnboundLocalError whenever antialias=True.
        blur_class = Blur if antialias else nn.Identity

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        self.latent_dim = latent_dim

        fmap_max = default(fmap_max, latent_dim)

        # (b, latent_dim, 1, 1) -> kernel-4 transposed conv gives 4x4 with
        # 2 * latent_dim channels; GLU over dim 1 halves that to latent_dim.
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1),
        )

        num_layers = int(resolution) - 2
        res_layers = range(2, num_layers + 2)

        # Per-stage output channels: 2 ** (coef - res), clamped to fmap_max.
        # Stages with res >= 8 are pinned to 3 channels. This preserves the
        # existing design; the intent is not documented here.
        stage_channels = [
            3 if res >= 8 else min(2 ** (fmap_inverse_coef - res), fmap_max)
            for res in res_layers
        ]
        features = [latent_dim, *stage_channels]
        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = res_layers
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # Source stage -> destination resolution index for cross-stage
        # gating. Only pairs that fit within this image size are kept.
        self.sle_map = {
            src: dst
            for src, dst in ((3, 7), (4, 8), (5, 9), (6, 10))
            if src <= resolution and dst <= resolution
        }

        self.num_layers_spatial_res = 1

        # Module creation order (attn, then sle, then the upsample stack) is
        # kept as-is so seeded parameter initialisation stays reproducible.
        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                # The gate is applied to the output of stage
                # (residual_layer - 1), so it must match that stage's
                # output channels.
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)

            layer = nn.ModuleList(
                [
                    nn.Sequential(
                        PixelShuffleUpsample(chan_in),
                        blur_class(),
                        Conv2dSame(chan_in, chan_out * 2, 4),
                        Noise(),
                        norm_class(chan_out * 2),
                        nn.GLU(dim=1),
                    ),
                    sle,  # may be None; forward() checks with exists()
                    attn,  # may be None; forward() checks with exists()
                ]
            )
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)

    def forward(self, x):
        """Generate a batch of images.

        Args:
            x: 2-D latent batch ``(b, c)``; ``c`` must equal ``latent_dim``
                (the input channels of ``initial_conv``).

        Returns:
            The ``out_conv`` output with ``init_channel`` channels. The
            spatial size is presumably ``image_size`` (assumes each stage
            doubles the resolution).
        """
        x = rearrange(x, "b c -> b c () ()")
        x = self.initial_conv(x)
        # L2-normalise each spatial position's feature vector over channels.
        x = F.normalize(x, dim=1)

        # Destination resolution index -> pending gating tensor.
        residuals = dict()

        for res, (up, sle, attn) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            # A gate stored for index res + 1 is applied to this stage's output.
            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)
|
|
|