"""FastGAN-style generator and discriminator with skip-layer excitation (SLE)
blocks: the generator emits images at two resolutions (1024px and 128px), and
the discriminator runs a two-scale track plus decoder reconstructions for a
self-supervised auxiliary loss.
"""

import torch
import torch.nn as nn

from typing import Any, Tuple, Union

from utils import (
    ImageType,
    crop_image_part,
)

from layers import (
    SpectralConv2d,
    InitLayer,
    SLEBlock,
    UpsampleBlockT1,
    UpsampleBlockT2,
    DownsampleBlockT1,
    DownsampleBlockT2,
    Decoder,
)

from huggan.pytorch.huggan_mixin import HugGANModelHubMixin


class Generator(nn.Module, HugGANModelHubMixin):
    """Maps a latent vector to a 1024x1024 image and an auxiliary 128x128 image.

    Low-resolution feature maps gate higher-resolution ones through SLE blocks.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        # Channel width per spatial resolution.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }

        # Projects the latent vector to a 4x4 feature map.
        self._init = InitLayer(
            in_channels=in_channels,
            out_channels=self._channels[4],
        )

        # Upsampling pyramid from 4x4 to 1024x1024, alternating block types.
        self._upsample_8    = UpsampleBlockT2(in_channels=self._channels[4],   out_channels=self._channels[8])
        self._upsample_16   = UpsampleBlockT1(in_channels=self._channels[8],   out_channels=self._channels[16])
        self._upsample_32   = UpsampleBlockT2(in_channels=self._channels[16],  out_channels=self._channels[32])
        self._upsample_64   = UpsampleBlockT1(in_channels=self._channels[32],  out_channels=self._channels[64])
        self._upsample_128  = UpsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[128])
        self._upsample_256  = UpsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[256])
        self._upsample_512  = UpsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[512])
        self._upsample_1024 = UpsampleBlockT1(in_channels=self._channels[512], out_channels=self._channels[1024])

        # Skip-layer excitation: low-resolution activations modulate
        # feature maps 16x their resolution.
        self._sle_64  = SLEBlock(in_channels=self._channels[4],  out_channels=self._channels[64])
        self._sle_128 = SLEBlock(in_channels=self._channels[8],  out_channels=self._channels[128])
        self._sle_256 = SLEBlock(in_channels=self._channels[16], out_channels=self._channels[256])
        self._sle_512 = SLEBlock(in_channels=self._channels[32], out_channels=self._channels[512])

        # Output heads: to-RGB convolutions with tanh activation.
        self._out_128 = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[128],
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

        self._out_1024 = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[1024],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        size_4  = self._init(input)
        size_8  = self._upsample_8(size_4)
        size_16 = self._upsample_16(size_8)
        size_32 = self._upsample_32(size_16)

        size_64  = self._sle_64(size_4,   self._upsample_64(size_32))
        size_128 = self._sle_128(size_8,  self._upsample_128(size_64))
        size_256 = self._sle_256(size_16, self._upsample_256(size_128))
        size_512 = self._sle_512(size_32, self._upsample_512(size_256))

        size_1024 = self._upsample_1024(size_512)

        out_128  = self._out_128(size_128)
        out_1024 = self._out_1024(size_1024)
        return out_1024, out_128


class Discriminator(nn.Module, HugGANModelHubMixin):
    """Two-track discriminator over 1024x1024 and 128x128 images.

    For non-fake inputs it also returns decoder reconstructions, which drive
    the self-supervised auxiliary loss.
    """

    def __init__(self, in_channels: int):
        super().__init__()

        # Channel width per spatial resolution.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }

        # Stem: two strided convolutions take 1024x1024 inputs down to 256x256.
        self._init = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=self._channels[1024],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=self._channels[1024],
                out_channels=self._channels[512],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self._channels[512]),
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Downsampling pyramid of the large track.
        self._downsample_256 = DownsampleBlockT2(in_channels=self._channels[512], out_channels=self._channels[256])
        self._downsample_128 = DownsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[128])
        self._downsample_64  = DownsampleBlockT2(in_channels=self._channels[128], out_channels=self._channels[64])
        self._downsample_32  = DownsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[32])
        self._downsample_16  = DownsampleBlockT2(in_channels=self._channels[32],  out_channels=self._channels[16])

        # SLE blocks mirror the generator's skip-layer excitation.
        self._sle_64 = SLEBlock(in_channels=self._channels[512], out_channels=self._channels[64])
        self._sle_32 = SLEBlock(in_channels=self._channels[256], out_channels=self._channels[32])
        self._sle_16 = SLEBlock(in_channels=self._channels[128], out_channels=self._channels[16])

        # Small track consumes the 128x128 images directly.
        self._small_track = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=self._channels[256],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            DownsampleBlockT1(in_channels=self._channels[256], out_channels=self._channels[128]),
            DownsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[64]),
            DownsampleBlockT1(in_channels=self._channels[64],  out_channels=self._channels[32]),
        )

        # Real/fake feature heads for each track.
        self._features_large = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[16],
                out_channels=self._channels[8],
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self._channels[8]),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=self._channels[8],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
        )

        self._features_small = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[32],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
        )

        # Decoders reconstruct RGB images for the self-supervised loss.
        self._decoder_large = Decoder(in_channels=self._channels[16], out_channels=3)
        self._decoder_small = Decoder(in_channels=self._channels[32], out_channels=3)
        self._decoder_piece = Decoder(in_channels=self._channels[32], out_channels=3)

    def forward(self, images_1024: torch.Tensor, images_128: torch.Tensor,
                image_type: ImageType) -> Union[torch.Tensor,
                                                Tuple[torch.Tensor, Tuple[Any, Any, Any]]]:
        # large track
        down_512 = self._init(images_1024)
        down_256 = self._downsample_256(down_512)
        down_128 = self._downsample_128(down_256)

        down_64 = self._downsample_64(down_128)
        down_64 = self._sle_64(down_512, down_64)

        down_32 = self._downsample_32(down_64)
        down_32 = self._sle_32(down_256, down_32)

        down_16 = self._downsample_16(down_32)
        down_16 = self._sle_16(down_128, down_16)

        # small track
        down_small = self._small_track(images_128)

        # features
        features_large = self._features_large(down_16).view(-1)
        features_small = self._features_small(down_small).view(-1)
        features = torch.cat([features_large, features_small], dim=0)

        # decoder: only real (non-fake) images get reconstruction targets
        if image_type != ImageType.FAKE:
            dec_large = self._decoder_large(down_16)
            dec_small = self._decoder_small(down_small)
            dec_piece = self._decoder_piece(crop_image_part(down_32, image_type))
            return features, (dec_large, dec_small, dec_piece)

        return features
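

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). The latent size
# of 256, RGB channel count of 3, and the (batch, latent) noise layout are
# illustrative assumptions about InitLayer's expected input; only the fake-image
# discriminator path is exercised because ImageType.FAKE is the one member the
# forward pass above confirms to exist.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    generator = Generator(in_channels=256, out_channels=3)
    discriminator = Discriminator(in_channels=3)

    noise = torch.randn(2, 256)  # assumed latent layout for InitLayer
    with torch.no_grad():
        images_1024, images_128 = generator(noise)
        # expected: (2, 3, 1024, 1024) and (2, 3, 128, 128)
        print(images_1024.shape, images_128.shape)

        # Fake path returns only the flattened real/fake features of both tracks.
        features = discriminator(images_1024, images_128, ImageType.FAKE)
        print(features.shape)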