Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from typing import Any, Tuple, Union | |
from utils import ( | |
ImageType, | |
crop_image_part, | |
) | |
from layers import ( | |
SpectralConv2d, | |
InitLayer, | |
SLEBlock, | |
UpsampleBlockT1, | |
UpsampleBlockT2, | |
DownsampleBlockT1, | |
DownsampleBlockT2, | |
Decoder, | |
) | |
from huggan.pytorch.huggan_mixin import HugGANModelHubMixin | |
class Generator(nn.Module, HugGANModelHubMixin):
    """Generator that returns two tanh-activated outputs.

    The network is an ``InitLayer`` followed by a chain of eight upsample
    blocks (alternating ``UpsampleBlockT2`` / ``UpsampleBlockT1``).  Four
    ``SLEBlock`` modules combine an early feature map with a later one.
    Two output heads (spectral conv + ``tanh``) read the ``128`` and
    ``1024`` stages.

    NOTE(review): the numeric suffixes (4 ... 1024) look like spatial
    resolutions; this depends on the blocks in ``layers`` -- confirm there.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        """Build all sub-modules.

        Args:
            in_channels: channel count passed to ``InitLayer``.
            out_channels: channel count produced by both output heads.
        """
        super().__init__()

        # Stage -> channel width lookup used by every layer below.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }
        widths = self._channels

        def stage(src: int, dst: int) -> dict:
            # Keyword arguments for a block going from stage ``src`` to ``dst``.
            return {"in_channels": widths[src], "out_channels": widths[dst]}

        def image_head(src: int, kernel_size: int) -> nn.Sequential:
            # Spectral conv to ``out_channels`` followed by tanh.
            return nn.Sequential(
                SpectralConv2d(
                    in_channels=widths[src],
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding='same',
                    bias=False,
                ),
                nn.Tanh(),
            )

        # NOTE: sub-modules are created in the same order as before so that
        # parameter ordering and random initialisation stay unchanged.
        self._init = InitLayer(in_channels=in_channels, out_channels=widths[4])

        # Upsampling chain: 4 -> 8 -> ... -> 1024, alternating block types.
        self._upsample_8 = UpsampleBlockT2(**stage(4, 8))
        self._upsample_16 = UpsampleBlockT1(**stage(8, 16))
        self._upsample_32 = UpsampleBlockT2(**stage(16, 32))
        self._upsample_64 = UpsampleBlockT1(**stage(32, 64))
        self._upsample_128 = UpsampleBlockT2(**stage(64, 128))
        self._upsample_256 = UpsampleBlockT1(**stage(128, 256))
        self._upsample_512 = UpsampleBlockT2(**stage(256, 512))
        self._upsample_1024 = UpsampleBlockT1(**stage(512, 1024))

        # SLE blocks: each takes an early stage and a much later stage.
        self._sle_64 = SLEBlock(**stage(4, 64))
        self._sle_128 = SLEBlock(**stage(8, 128))
        self._sle_256 = SLEBlock(**stage(16, 256))
        self._sle_512 = SLEBlock(**stage(32, 512))

        # Output heads: 1x1 conv on the 128 stage, 3x3 conv on the 1024 stage.
        self._out_128 = image_head(128, kernel_size=1)
        self._out_1024 = image_head(1024, kernel_size=3)

    def forward(self, input: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Run the generator.

        Args:
            input: tensor fed to the initial layer (expected shape is
                defined by ``InitLayer`` -- not visible here).

        Returns:
            ``(out_1024, out_128)``: output of the ``1024`` head first,
            then output of the ``128`` head.
        """
        feat_4 = self._init(input)
        feat_8 = self._upsample_8(feat_4)
        feat_16 = self._upsample_16(feat_8)
        feat_32 = self._upsample_32(feat_16)

        # From stage 64 on, every upsampled map is passed through an SLE
        # block together with a feature map from four stages earlier.
        up_64 = self._upsample_64(feat_32)
        feat_64 = self._sle_64(feat_4, up_64)

        up_128 = self._upsample_128(feat_64)
        feat_128 = self._sle_128(feat_8, up_128)

        up_256 = self._upsample_256(feat_128)
        feat_256 = self._sle_256(feat_16, up_256)

        up_512 = self._upsample_512(feat_256)
        feat_512 = self._sle_512(feat_32, up_512)

        feat_1024 = self._upsample_1024(feat_512)

        small_image = self._out_128(feat_128)
        large_image = self._out_1024(feat_1024)
        return large_image, small_image
class Discriminrator(nn.Module, HugGANModelHubMixin):
    """Two-track discriminator with auxiliary decoders.

    A "large" track processes ``images_1024`` and a "small" track processes
    ``images_128``.  Each track ends in a feature head whose output is
    flattened; the two are concatenated into one 1-D tensor.  For image
    types other than ``ImageType.FAKE`` three ``Decoder`` modules
    additionally decode intermediate feature maps.

    NOTE(review): the class name looks like a misspelling of
    "Discriminator"; it is kept as-is because callers import it by this name.
    """

    def __init__(self, in_channels: int):
        """Build all sub-modules.

        Args:
            in_channels: channel count of both input image tensors.
        """
        super().__init__()

        # Stage -> channel width lookup used by every layer below.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }
        widths = self._channels

        def stage(src: int, dst: int) -> dict:
            # Keyword arguments for a block going from stage ``src`` to ``dst``.
            return {"in_channels": widths[src], "out_channels": widths[dst]}

        def strided_conv(n_in: int, n_out: int) -> nn.Module:
            # 4x4 spectral conv, stride 2, padding 1, no bias.
            return SpectralConv2d(
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )

        def unpadded_conv(n_in: int, n_out: int, kernel_size: int) -> nn.Module:
            # Stride-1 spectral conv without padding or bias.
            return SpectralConv2d(
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=kernel_size,
                stride=1,
                padding=0,
                bias=False,
            )

        # NOTE: sub-modules are created in the same order as before so that
        # parameter ordering and random initialisation stay unchanged.

        # Large-track stem: two strided convs (1024 -> 512 widths).
        self._init = nn.Sequential(
            strided_conv(in_channels, widths[1024]),
            nn.LeakyReLU(negative_slope=0.2),
            strided_conv(widths[1024], widths[512]),
            nn.BatchNorm2d(num_features=widths[512]),
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Large-track downsampling chain.
        self._downsample_256 = DownsampleBlockT2(**stage(512, 256))
        self._downsample_128 = DownsampleBlockT2(**stage(256, 128))
        self._downsample_64 = DownsampleBlockT2(**stage(128, 64))
        self._downsample_32 = DownsampleBlockT2(**stage(64, 32))
        self._downsample_16 = DownsampleBlockT2(**stage(32, 16))

        # SLE blocks pairing an earlier large-track stage with a later one.
        self._sle_64 = SLEBlock(**stage(512, 64))
        self._sle_32 = SLEBlock(**stage(256, 32))
        self._sle_16 = SLEBlock(**stage(128, 16))

        # Small track: strided conv stem followed by three T1 blocks.
        self._small_track = nn.Sequential(
            strided_conv(in_channels, widths[256]),
            nn.LeakyReLU(negative_slope=0.2),
            DownsampleBlockT1(**stage(256, 128)),
            DownsampleBlockT1(**stage(128, 64)),
            DownsampleBlockT1(**stage(64, 32)),
        )

        # Feature heads, each ending in a single-channel 4x4 conv.
        self._features_large = nn.Sequential(
            unpadded_conv(widths[16], widths[8], kernel_size=1),
            nn.BatchNorm2d(num_features=widths[8]),
            nn.LeakyReLU(negative_slope=0.2),
            unpadded_conv(widths[8], 1, kernel_size=4),
        )
        self._features_small = nn.Sequential(
            unpadded_conv(widths[32], 1, kernel_size=4),
        )

        # Auxiliary decoders, each producing 3 output channels.
        self._decoder_large = Decoder(in_channels=widths[16], out_channels=3)
        self._decoder_small = Decoder(in_channels=widths[32], out_channels=3)
        self._decoder_piece = Decoder(in_channels=widths[32], out_channels=3)

    def forward(self, images_1024: torch.Tensor,
                images_128: torch.Tensor,
                image_type: ImageType) -> \
            Union[
                torch.Tensor,
                Tuple[torch.Tensor, Tuple[Any, Any, Any]]
            ]:
        """Score a pair of images and optionally decode features.

        Args:
            images_1024: input to the large track.
            images_128: input to the small track.
            image_type: selects whether decoders run and is forwarded to
                ``crop_image_part``.

        Returns:
            If ``image_type`` is ``ImageType.FAKE``: the 1-D ``features``
            tensor only.  Otherwise ``(features, (dec_large, dec_small,
            dec_piece))``.
        """
        # Large track.
        stem_512 = self._init(images_1024)
        stage_256 = self._downsample_256(stem_512)
        stage_128 = self._downsample_128(stage_256)
        stage_64 = self._sle_64(stem_512, self._downsample_64(stage_128))
        stage_32 = self._sle_32(stage_256, self._downsample_32(stage_64))
        stage_16 = self._sle_16(stage_128, self._downsample_16(stage_32))

        # Small track.
        small_map = self._small_track(images_128)

        # Flatten both heads and join them into one vector.
        logits_large = self._features_large(stage_16).view(-1)
        logits_small = self._features_small(small_map).view(-1)
        features = torch.cat([logits_large, logits_small], dim=0)

        # Decoders only run for non-fake images.
        if image_type != ImageType.FAKE:
            rec_large = self._decoder_large(stage_16)
            rec_small = self._decoder_small(small_map)
            # The "piece" decoder works on a crop of the 32-stage map chosen
            # by ``crop_image_part`` from ``image_type``.
            cropped = crop_image_part(stage_32, image_type)
            rec_piece = self._decoder_piece(cropped)
            reconstructions = (rec_large, rec_small, rec_piece)
            return features, reconstructions

        return features