|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.modules.batchnorm import BatchNorm2d |
|
from torch.nn.utils import spectral_norm |
|
|
|
|
|
class SpectralConv2d(nn.Module):
    """``nn.Conv2d`` wrapped with spectral normalization.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.Conv2d``; the normalized layer is stored as ``_conv``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Build the plain convolution first, then hand it to spectral_norm,
        # which re-parametrizes its weight and returns the same module.
        conv = nn.Conv2d(*args, **kwargs)
        self._conv = spectral_norm(conv)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run ``input`` through the spectrally normalized convolution."""
        output = self._conv(input)
        return output
|
|
|
|
|
class SpectralConvTranspose2d(nn.Module):
    """``nn.ConvTranspose2d`` wrapped with spectral normalization.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.ConvTranspose2d``; the normalized layer is stored as ``_conv``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Create the transposed convolution, then let spectral_norm
        # re-parametrize its weight (it returns the same module).
        deconv = nn.ConvTranspose2d(*args, **kwargs)
        self._conv = spectral_norm(deconv)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run ``input`` through the normalized transposed convolution."""
        output = self._conv(input)
        return output
|
|
|
|
|
class Noise(nn.Module):
    """Add learnably scaled Gaussian noise to a 4-D feature map.

    A single-channel standard-normal noise map of shape
    ``(batch, 1, height, width)`` is sampled on every call, multiplied by a
    learnable scalar and broadcast-added to all channels of the input.
    The scalar starts at zero, so a freshly built module is an identity.
    """

    def __init__(self):
        super().__init__()
        # One learnable scale shared by the whole feature map.
        self._weight = nn.Parameter(
            torch.zeros(1),
            requires_grad=True,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` plus the scaled noise.

        Args:
            input: Tensor with exactly four dimensions; its shape is
                unpacked as ``(batch, channels, height, width)``.

        Returns:
            Tensor broadcast-compatible with ``input`` containing
            ``weight * noise + input``.
        """
        batch_size, _, height, width = input.shape
        # Sample in the input's floating dtype so a half/bfloat16/double
        # model is not silently promoted to float32 by the noise term.
        # Non-floating inputs keep the previous behaviour (default dtype).
        noise_dtype = input.dtype if input.is_floating_point() else None
        noise = torch.randn(
            batch_size, 1, height, width,
            dtype=noise_dtype,
            device=input.device,
        )
        return self._weight * noise + input
|
|
|
|
|
class InitLayer(nn.Module):
    """Stem block: spectral transposed conv -> BatchNorm -> GLU.

    The transposed convolution uses a 4x4 kernel with stride 1 and no
    padding, so each spatial side grows by 3 (a 1x1 input becomes 4x4).
    It emits ``2 * out_channels`` features, which the GLU over the channel
    axis gates down to ``out_channels``.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        # GLU halves the channel axis, so conv and norm run at double width.
        gated_channels = out_channels * 2

        deconv = SpectralConvTranspose2d(
            in_channels=in_channels,
            out_channels=gated_channels,
            kernel_size=4,
            stride=1,
            padding=0,
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=gated_channels)
        gate = nn.GLU(dim=1)

        self._layers = nn.Sequential(deconv, norm, gate)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the stem to ``input`` and return the gated feature map."""
        features = self._layers(input)
        return features
|
|
|
|
|
class SLEBlock(nn.Module):
    """Channel gate computed from one feature map and applied to another.

    ``low_dim`` is average-pooled to 4x4, collapsed to 1x1 by a 4x4
    spectral convolution, passed through SiLU and a 1x1 spectral
    convolution, and squashed with a sigmoid. The resulting
    ``(batch, out_channels, 1, 1)`` gate multiplies ``high_dim``.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        squeeze = nn.AdaptiveAvgPool2d(output_size=4)
        # A 4x4 kernel without padding turns the pooled 4x4 map into 1x1.
        collapse = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=1,
            padding=0,
            bias=False,
        )
        activation = nn.SiLU()
        project = SpectralConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        to_gate = nn.Sigmoid()

        self._layers = nn.Sequential(
            squeeze, collapse, activation, project, to_gate,
        )

    def forward(self, low_dim: torch.Tensor,
                high_dim: torch.Tensor) -> torch.Tensor:
        """Scale ``high_dim`` by the gate derived from ``low_dim``.

        ``high_dim`` must broadcast against the gate; presumably it carries
        ``out_channels`` channels - confirm against callers.
        """
        gate = self._layers(low_dim)
        return high_dim * gate
|
|
|
|
|
class UpsampleBlockT1(nn.Module):
    """Single-convolution 2x upsampling block.

    Pipeline: nearest-neighbour upsample by 2 -> 3x3 'same' spectral
    convolution to ``2 * out_channels`` -> BatchNorm -> GLU over channels,
    which leaves ``out_channels`` features.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        # The GLU consumes half of the channels as gates.
        gated_channels = out_channels * 2

        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=gated_channels,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=gated_channels)
        gate = nn.GLU(dim=1)

        self._layers = nn.Sequential(upsample, conv, norm, gate)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` upsampled by 2 with ``out_channels`` channels."""
        upsampled = self._layers(input)
        return upsampled
|
|
|
|
|
class UpsampleBlockT2(nn.Module):
    """Two-convolution 2x upsampling block with noise injection.

    Pipeline: nearest-neighbour upsample by 2, then two stages of
    3x3 'same' spectral convolution -> ``Noise`` -> BatchNorm -> GLU.
    Each convolution emits ``2 * out_channels`` features and each GLU
    (over the channel axis) reduces them to ``out_channels``.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        self._layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            # Stage 1: in_channels -> out_channels (after gating).
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            Noise(),
            # Referenced through ``nn`` like every other block in this file
            # (it is the same class as the directly imported BatchNorm2d).
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
            # Stage 2: out_channels -> out_channels (after gating).
            SpectralConv2d(
                in_channels=out_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            Noise(),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` upsampled by 2 with ``out_channels`` channels."""
        return self._layers(input)
|
|
|
|
|
class DownsampleBlockT1(nn.Module):
    """Single-convolution downsampling block.

    Pipeline: 4x4 spectral convolution with stride 2 and padding 1
    (halves each spatial side) -> BatchNorm -> LeakyReLU(0.2).
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        strided_conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=out_channels)
        activation = nn.LeakyReLU(negative_slope=0.2)

        self._layers = nn.Sequential(strided_conv, norm, activation)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the downsampled feature map with ``out_channels`` channels."""
        downsampled = self._layers(input)
        return downsampled
|
|
|
|
|
class DownsampleBlockT2(nn.Module):
    """Two-branch downsampling block whose branch outputs are averaged.

    * ``_layers_1``: 4x4 stride-2 spectral conv -> BatchNorm -> LeakyReLU
      -> 3x3 'same' spectral conv -> BatchNorm -> LeakyReLU.
    * ``_layers_2``: 2x2 average pooling -> 1x1 spectral conv -> BatchNorm
      -> LeakyReLU.

    Both branches halve the spatial size and emit ``out_channels`` features.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        # Convolutional branch: learned strided downsampling + refinement.
        strided_conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        strided_norm = nn.BatchNorm2d(num_features=out_channels)
        strided_act = nn.LeakyReLU(negative_slope=0.2)
        refine_conv = SpectralConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,
        )
        refine_norm = nn.BatchNorm2d(num_features=out_channels)
        refine_act = nn.LeakyReLU(negative_slope=0.2)
        self._layers_1 = nn.Sequential(
            strided_conv, strided_norm, strided_act,
            refine_conv, refine_norm, refine_act,
        )

        # Pooling branch: fixed average pooling + 1x1 channel projection.
        pool = nn.AvgPool2d(
            kernel_size=2,
            stride=2,
        )
        pointwise_conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        pointwise_norm = nn.BatchNorm2d(num_features=out_channels)
        pointwise_act = nn.LeakyReLU(negative_slope=0.2)
        self._layers_2 = nn.Sequential(
            pool, pointwise_conv, pointwise_norm, pointwise_act,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the element-wise mean of the two branch outputs."""
        conv_branch = self._layers_1(input)
        pool_branch = self._layers_2(input)
        return (conv_branch + pool_branch) / 2
|
|
|
|
|
class Decoder(nn.Module):
    """Small decoder that maps a feature map to a 128x128 output.

    The input is adaptively average-pooled to 8x8, passed through four
    ``UpsampleBlockT1`` stages (each doubling the spatial size), then a
    3x3 'same' spectral convolution to ``out_channels`` and a ``tanh``.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        # Channel width per stage; keys are presumably the spatial
        # resolution each width belongs to. Only 16..128 are used here.
        self._channels = {
            16: 128,
            32: 64,
            64: 64,
            128: 32,
            256: 16,
            512: 8,
            1024: 4,
        }

        # Widths along the upsampling chain:
        # in_channels -> [16] -> [32] -> [64] -> [128].
        widths = [in_channels]
        widths.extend(self._channels[size] for size in (16, 32, 64, 128))

        stages = [nn.AdaptiveAvgPool2d(output_size=8)]
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            stages.append(
                UpsampleBlockT1(in_channels=width_in, out_channels=width_out)
            )
        stages.append(
            SpectralConv2d(
                in_channels=widths[-1],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            )
        )
        stages.append(nn.Tanh())

        self._layers = nn.Sequential(*stages)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Decode ``input`` into a tanh-bounded ``out_channels`` map."""
        decoded = self._layers(input)
        return decoded
|
|