from math import log2

import torch.nn.functional as F
from einops import rearrange
from torch import nn

from .Blur import Blur
from .Noise import Noise
from .FCANet import FCANet
from .PreNorm import PreNorm
from .Conv2dSame import Conv2dSame
from .GlobalContext import GlobalContext
from .LinearAttention import LinearAttention
from .PixelShuffleUpsample import PixelShuffleUpsample
from .helper_funcs import exists, is_power_of_two, default


class Generator(nn.Module):
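    """Progressive upsampling generator.

    Maps a latent vector to an image through a stack of pixel-shuffle
    upsampling blocks, with optional linear attention at selected
    resolutions and skip-layer excitation (SLE) connections that let
    low-resolution feature maps gate higher-resolution ones.
    """
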
def __init__(
self,
*,
image_size,
latent_dim=256,
fmap_max=512,
fmap_inverse_coef=12,
transparent=False,
greyscale=False,
        attn_res_layers=(),  # tuple default avoids the mutable-default-argument pitfall
freq_chan_attn=False,
syncbatchnorm=False,
antialias=False,
):
super().__init__()
resolution = log2(image_size)
assert is_power_of_two(image_size), "image size must be a power of 2"
        # Choose the normalization and blur layers. Assigning back to the
        # imported name `Blur` would make it local to __init__ and raise
        # UnboundLocalError when antialias=True, so bind a fresh name instead.
        norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
        blur_class = Blur if antialias else nn.Identity
if transparent:
init_channel = 4
elif greyscale:
init_channel = 1
else:
init_channel = 3
self.latent_dim = latent_dim
fmap_max = default(fmap_max, latent_dim)
self.initial_conv = nn.Sequential(
nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
norm_class(latent_dim * 2),
nn.GLU(dim=1),
)
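        # Channel schedule: resolution level r (image width 2**r) is assigned
        # min(2 ** (fmap_inverse_coef - r), fmap_max) channels, levels r >= 8
        # are clamped to 3 channels, and latent_dim is prepended as the input
        # width of the first block.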
num_layers = int(resolution) - 2
features = list(
map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2))
)
features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
features = [latent_dim, *features]
in_out_features = list(zip(features[:-1], features[1:]))
self.res_layers = range(2, num_layers + 2)
self.layers = nn.ModuleList([])
self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))
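        # Skip-layer excitation (SLE) wiring: the output of the block at
        # resolution level k gates the feature map entering level v for each
        # (k, v) pair, dropping pairs that exceed the target resolution.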
self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
self.sle_map = list(
filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map)
)
self.sle_map = dict(self.sle_map)
self.num_layers_spatial_res = 1
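        # Build one upsampling block per resolution level, attaching optional
        # linear attention and an SLE module where configured.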
for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
image_width = 2**res
attn = None
if image_width in attn_res_layers:
attn = PreNorm(chan_in, LinearAttention(chan_in))
sle = None
if res in self.sle_map:
residual_layer = self.sle_map[res]
sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
if freq_chan_attn:
sle = FCANet(
chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
)
else:
sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)
layer = nn.ModuleList(
[
nn.Sequential(
PixelShuffleUpsample(chan_in),
                        blur_class(),
Conv2dSame(chan_in, chan_out * 2, 4),
Noise(),
norm_class(chan_out * 2),
nn.GLU(dim=1),
),
sle,
attn,
]
)
self.layers.append(layer)
self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)

    def forward(self, x):
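        # The latent vector (b, latent_dim) is treated as a 1x1 feature map so
        # the initial transposed convolution can expand it to a 4x4 grid.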
x = rearrange(x, "b c -> b c () ()")
x = self.initial_conv(x)
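        # L2-normalize each spatial feature vector across the channel dimension.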
x = F.normalize(x, dim=1)
residuals = dict()
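        # Walk up the resolution pyramid; SLE outputs are cached under their
        # target level and multiplied in when that level's input is reached.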
for res, (up, sle, attn) in zip(self.res_layers, self.layers):
if exists(attn):
x = attn(x) + x
x = up(x)
if exists(sle):
out_res = self.sle_map[res]
residual = sle(x)
residuals[out_res] = residual
next_res = res + 1
if next_res in residuals:
x = x * residuals[next_res]
return self.out_conv(x)
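

# Example usage (a minimal sketch; assumes this module lives inside a package
# so the relative imports above resolve):
#
#     import torch
#     gen = Generator(image_size=256, latent_dim=256)
#     latents = torch.randn(4, 256)
#     images = gen(latents)  # -> tensor of shape (4, 3, 256, 256)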