import math
from math import log2

import torch
import torch.nn.functional as F
from torch import nn, einsum

from kornia.filters import filter2d
from einops import reduce, rearrange, repeat


# small helper functions
def exists(val):
    return val is not None


def is_power_of_two(val):
    return log2(val).is_integer()


def default(val, d):
    return val if exists(val) else d


def get_1d_dct(i, freq, L):
    # 1D DCT-II basis value for position i and frequency index freq over length L
    result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
    return result * (1 if freq == 0 else math.sqrt(2))


def get_dct_weights(width, channel, fidx_u, fidx_v):
    # precomputed 2D DCT filter bank of shape (1, channel, width, width);
    # each contiguous chunk of channels is assigned one (u, v) frequency pair.
    # FCANet below registers this as a fixed buffer.
    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)

    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        for x in range(width):
            for y in range(width):
                coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
                dct_weights[:, i * c_part : (i + 1) * c_part, x, y] = coor_value

    return dct_weights


# 3x3 binomial blur, used when antialiasing is enabled
class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer("f", f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)


# layer normalization over the channel dimension
class ChanNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


# convolution with 'same' output size, handling even kernel sizes via asymmetric padding
def Conv2dSame(dim_in, dim_out, kernel_size, bias=True):
    pad_left = kernel_size // 2
    pad_right = (pad_left - 1) if (kernel_size % 2) == 0 else pad_left

    return nn.Sequential(
        nn.ZeroPad2d((pad_left, pad_right, pad_left, pad_right)),
        nn.Conv2d(dim_in, dim_out, kernel_size, bias=bias),
    )


# depthwise separable convolution
class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(
                dim_in,
                dim_in,
                kernel_size=kernel_size,
                padding=padding,
                groups=dim_in,
                stride=stride,
                bias=bias,
            ),
            nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias),
        )

    def forward(self, x):
        return self.net(x)


# frequency channel attention (FcaNet-style squeeze-excite)
class FCANet(nn.Module):
    def __init__(self, *, chan_in, chan_out, reduction=4, width):
        super().__init__()

        # 16 (u, v) frequency pairs for the DCT pooling
        freq_w, freq_h = [0] * 8, list(range(8))
        dct_weights = get_dct_weights(
            width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w]
        )
        self.register_buffer("dct_weights", dct_weights)

        chan_intermediate = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # DCT-weighted global pooling, followed by a 1x1 squeeze-excite gate
        x = reduce(
            x * self.dct_weights, "b c (h h1) (w w1) -> b c h1 w1", "sum", h1=1, w1=1
        )
        return self.net(x)


# generator
class Generator(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=[],
        freq_chan_attn=False,
        syncbatchnorm=False,
        antialias=False,
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), "image size must be a power of 2"

        norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
        # use a distinct local name so the module-level Blur class is not shadowed
        # (assigning to `Blur` here would raise UnboundLocalError when antialias=True)
        blur_class = Blur if antialias else nn.Identity

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        self.latent_dim = latent_dim

        fmap_max = default(fmap_max, latent_dim)

        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1),
        )

        num_layers = int(resolution) - 2
        features = list(
            map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2))
        )
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # skip-layer excitation (SLE) connections: low-resolution layer -> high-resolution layer
        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = list(
            filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map)
        )
        self.sle_map = dict(self.sle_map)

        self.num_layers_spatial_res = 1

        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)

            layer = nn.ModuleList(
                [
                    nn.Sequential(
                        PixelShuffleUpsample(chan_in),
                        blur_class(),
                        Conv2dSame(chan_in, chan_out * 2, 4),
                        Noise(),
                        norm_class(chan_out * 2),
                        nn.GLU(dim=1),
                    ),
                    sle,
                    attn,
                ]
            )
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)

    def forward(self, x):
        x = rearrange(x, "b c -> b c () ()")
        x = self.initial_conv(x)
        x = F.normalize(x, dim=1)

        residuals = dict()

        for res, (up, sle, attn) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)


# global-context gating (squeeze-excite style), used for the SLE connections
class GlobalContext(nn.Module):
    def __init__(self, *, chan_in, chan_out):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        chan_intermediate = max(3, chan_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # attention-pool spatial positions into one context vector per channel
        context = self.to_k(x)
        context = context.flatten(2).softmax(dim=-1)
        out = einsum("b i n, b c n -> b c i", context, x.flatten(2))
        out = out.unsqueeze(-1)
        return self.net(out)


# global linear attention combined with conv-like local full attention
class LinearAttention(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()

        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)

        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]

        # linear (global) attention

        lin_q, lin_k, lin_v = (
            self.to_lin_q(fmap),
            *self.to_lin_kv(fmap).chunk(2, dim=1),
        )
        lin_q, lin_k, lin_v = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=h),
            (lin_q, lin_k, lin_v),
        )

        lin_q = lin_q.softmax(dim=-1)
        lin_k = lin_k.softmax(dim=-2)

        lin_q = lin_q * self.scale

        context = einsum("b n d, b n e -> b d e", lin_k, lin_v)
        lin_out = einsum("b n d, b d e -> b n e", lin_q, context)
        lin_out = rearrange(lin_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)

        # conv-like full attention over local neighborhoods

        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) c x y", h=h), (q, k, v)
        )

        k = F.unfold(k, kernel_size=self.kernel_size, padding=self.kernel_size // 2)
        v = F.unfold(v, kernel_size=self.kernel_size, padding=self.kernel_size // 2)

        k, v = map(
            lambda t: rearrange(t, "b (d j) n -> b n j d", d=self.dim_head), (k, v)
        )

        q = rearrange(q, "b c ... -> b (...) c") * self.scale

        sim = einsum("b i d, b i j d -> b i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()

        attn = sim.softmax(dim=-1)

        full_out = einsum("b i j, b i j d -> b i d", attn, v)
        full_out = rearrange(full_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)

        # concatenate both branches and project out

        lin_out = self.nonlin(lin_out)
        out = torch.cat((lin_out, full_out), dim=1)
        return self.to_out(out)


# learned per-pixel noise injection
class Noise(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        b, _, h, w, device = *x.shape, x.device

        if not exists(noise):
            noise = torch.randn(b, 1, h, w, device=device)

        return x + self.weight * noise


# 2x upsample via 1x1 conv + pixel shuffle
class PixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(2))

        self.init_conv_(conv)

    def init_conv_(self, conv):
        # repeat each output filter 4 times so every 2x2 sub-pixel block starts identical
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)


# apply channel normalization before the wrapped function (used around attention)
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
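

# Minimal smoke test (a sketch, not part of the original module). The image size,
# latent dimension, and attention resolution below are illustrative assumptions
# chosen only to exercise the Generator quickly on CPU.
if __name__ == "__main__":
    gen = Generator(image_size=64, latent_dim=256, attn_res_layers=[32])
    latents = torch.randn(2, 256)
    images = gen(latents)
    print(images.shape)  # expected: torch.Size([2, 3, 64, 64])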