import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401 -- not referenced in this module
from torch.nn.modules.batchnorm import BatchNorm2d  # noqa: F401 -- same class as nn.BatchNorm2d; retained, nn.BatchNorm2d is used below
from torch.nn.utils import spectral_norm


class SpectralConv2d(nn.Module):
    """``nn.Conv2d`` wrapped with ``spectral_norm``.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv2d``. The wrapped layer lives under the ``_conv`` attribute,
    so its parameters appear in the state dict with a ``_conv.`` prefix.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._conv = spectral_norm(nn.Conv2d(*args, **kwargs))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._conv(input)


class SpectralConvTranspose2d(nn.Module):
    """``nn.ConvTranspose2d`` wrapped with ``spectral_norm``.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.ConvTranspose2d``; the wrapped layer lives under ``_conv``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._conv = spectral_norm(nn.ConvTranspose2d(*args, **kwargs))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._conv(input)


class Noise(nn.Module):
    """Add learnably-scaled Gaussian noise to a 4-D feature map.

    One noise map of shape ``(batch, 1, height, width)`` is sampled on every
    forward call (there is no train/eval distinction, so eval mode is
    stochastic too) and broadcast across all channels. The scale is a single
    learnable scalar initialised to zero, so the layer starts out as an
    identity mapping.
    """

    def __init__(self):
        super().__init__()
        self._weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input + weight * noise`` with the same dtype as ``input``.

        Args:
            input: tensor unpacked as ``(batch, channels, height, width)``.
        """
        batch_size, _, height, width = input.shape
        # Sample in the default dtype (identical RNG draw to sampling directly
        # for float32 inputs), then cast to the input's dtype. A float32 noise
        # tensor would otherwise promote float16/bfloat16 activations to
        # float32, so a model converted with ``.half()`` would stop emitting
        # half-precision activations after this layer.
        noise = torch.randn(
            batch_size, 1, height, width, device=input.device
        ).to(dtype=input.dtype)
        return self._weight * noise + input


class InitLayer(nn.Module):
    """Spectral 4x4 transposed conv -> BatchNorm -> GLU.

    The transposed conv (kernel 4, stride 1, padding 0) grows each spatial
    dimension by 3 (e.g. 1x1 -> 4x4). It produces ``out_channels * 2``
    channels, which the channel-wise GLU halves back to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self._layers = nn.Sequential(
            SpectralConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels * 2,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)


class SLEBlock(nn.Module):
    """Gate ``high_dim`` channel-wise with sigmoid weights computed from ``low_dim``.

    ``low_dim`` (``in_channels`` channels, any spatial size) is average-pooled
    to 4x4, reduced to 1x1 by a 4x4 unpadded conv, passed through SiLU, a 1x1
    conv and a sigmoid. The resulting ``(batch, out_channels, 1, 1)`` gates
    are broadcast-multiplied into ``high_dim``, which therefore needs
    ``out_channels`` channels; its spatial size is independent of
    ``low_dim``'s.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self._layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=4),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.SiLU(),
            SpectralConv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sigmoid(),
        )

    def forward(self, low_dim: torch.Tensor, high_dim: torch.Tensor) -> torch.Tensor:
        return high_dim * self._layers(low_dim)


class UpsampleBlockT1(nn.Module):
    """Nearest-neighbour 2x upsample -> spectral 3x3 conv -> BatchNorm -> GLU.

    Doubles height and width; maps ``in_channels`` to ``out_channels`` (the
    conv emits twice as many channels and the GLU halves them).
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self._layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)


class UpsampleBlockT2(nn.Module):
    """Nearest-neighbour 2x upsample followed by two conv stages with noise.

    Each stage is spectral 3x3 conv -> ``Noise`` -> BatchNorm -> GLU. Doubles
    height and width and maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self._layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            Noise(),
            # nn.BatchNorm2d is the same class as the imported BatchNorm2d;
            # used here for consistency with the other blocks.
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
            SpectralConv2d(
                in_channels=out_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            Noise(),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)


class DownsampleBlockT1(nn.Module):
    """Spectral 4x4 stride-2 conv (padding 1) -> BatchNorm -> LeakyReLU(0.2).

    Halves height and width (rounding down) and maps ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self._layers = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)


class DownsampleBlockT2(nn.Module):
    """Mean of a two-conv downsampling path and an avg-pool shortcut path.

    Path 1: 4x4 stride-2 conv -> BN -> LeakyReLU -> 3x3 'same' conv -> BN ->
    LeakyReLU. Path 2: 2x2 average pool -> 1x1 conv -> BN -> LeakyReLU. Both
    paths halve height and width (rounding down), so their outputs can be
    averaged element-wise.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self._layers_1 = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self._layers_2 = nn.Sequential(
            nn.AvgPool2d(
                kernel_size=2,
                stride=2,
            ),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        t1 = self._layers_1(input)
        t2 = self._layers_2(input)
        return (t1 + t2) / 2


class Decoder(nn.Module):
    """Decode a feature map into a 128x128 ``out_channels`` image in (-1, 1).

    The input is average-pooled to 8x8 (so any input spatial size is
    accepted), upsampled four times by 2x to 128x128, projected by a 3x3
    spectral conv, and squashed with tanh.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Channel width keyed by the spatial resolution a stage produces;
        # only the 16..128 entries are used by this decoder.
        self._channels = {
            16: 128,
            32: 64,
            64: 64,
            128: 32,
            256: 16,
            512: 8,
            1024: 4,
        }
        self._layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=8),
            UpsampleBlockT1(in_channels=in_channels, out_channels=self._channels[16]),
            UpsampleBlockT1(in_channels=self._channels[16], out_channels=self._channels[32]),
            UpsampleBlockT1(in_channels=self._channels[32], out_channels=self._channels[64]),
            UpsampleBlockT1(in_channels=self._channels[64], out_channels=self._channels[128]),
            SpectralConv2d(
                in_channels=self._channels[128],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)