# Author: geninhu
# Commit message: "Upload models.py"
# Commit: d6c219d
import torch
import torch.nn as nn
from typing import Any, Tuple, Union
from utils import (
ImageType,
crop_image_part,
)
from layers import (
SpectralConv2d,
InitLayer,
SLEBlock,
UpsampleBlockT1,
UpsampleBlockT2,
DownsampleBlockT1,
DownsampleBlockT2,
Decoder,
)
from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
class Generator(nn.Module, HugGANModelHubMixin):
    """FastGAN-style generator that emits a full-size and an auxiliary image.

    The latent input goes through ``InitLayer`` and then a trunk of
    upsampling blocks that alternate ``UpsampleBlockT2`` / ``UpsampleBlockT1``.
    Stage labels (4, 8, ..., 1024) follow the attribute and dict-key naming
    used throughout this class. From stage 64 to stage 512, each upsampled
    feature map is passed through an ``SLEBlock`` together with the feature
    map of the stage labelled 16x smaller (4, 8, 16, 32 respectively).

    Args:
        in_channels: channel count of the input handed to ``InitLayer``.
        out_channels: channel count of both output images.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()
        # Channel width per stage, keyed by the stage label that appears in
        # the attribute names below (``_upsample_64`` -> key 64, ...).
        self._channels = {
            4: 1024,
            8: 512,
            16: 256,
            32: 128,
            64: 128,
            128: 64,
            256: 32,
            512: 16,
            1024: 8,
        }
        ch = self._channels

        # Submodules are created in a fixed order: creation order drives
        # parameter initialisation under a given seed, and the attribute
        # names are the ``state_dict`` keys used to load checkpoints.
        self._init = InitLayer(in_channels=in_channels, out_channels=ch[4])

        self._upsample_8 = UpsampleBlockT2(in_channels=ch[4], out_channels=ch[8])
        self._upsample_16 = UpsampleBlockT1(in_channels=ch[8], out_channels=ch[16])
        self._upsample_32 = UpsampleBlockT2(in_channels=ch[16], out_channels=ch[32])
        self._upsample_64 = UpsampleBlockT1(in_channels=ch[32], out_channels=ch[64])
        self._upsample_128 = UpsampleBlockT2(in_channels=ch[64], out_channels=ch[128])
        self._upsample_256 = UpsampleBlockT1(in_channels=ch[128], out_channels=ch[256])
        self._upsample_512 = UpsampleBlockT2(in_channels=ch[256], out_channels=ch[512])
        self._upsample_1024 = UpsampleBlockT1(in_channels=ch[512], out_channels=ch[1024])

        # SLE gates: in_channels is the width of the earlier (smaller-label)
        # stage, out_channels the width of the stage being modulated.
        self._sle_64 = SLEBlock(in_channels=ch[4], out_channels=ch[64])
        self._sle_128 = SLEBlock(in_channels=ch[8], out_channels=ch[128])
        self._sle_256 = SLEBlock(in_channels=ch[16], out_channels=ch[256])
        self._sle_512 = SLEBlock(in_channels=ch[32], out_channels=ch[512])

        # Image heads; the small one uses a 1x1 kernel, the large one 3x3.
        self._out_128 = self._make_head(ch[128], out_channels, kernel_size=1)
        self._out_1024 = self._make_head(ch[1024], out_channels, kernel_size=3)

    @staticmethod
    def _make_head(in_ch: int, out_ch: int, kernel_size: int) -> nn.Sequential:
        """Return a bias-free, same-padded SpectralConv2d followed by Tanh."""
        return nn.Sequential(
            SpectralConv2d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=kernel_size,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

    def forward(self, input: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Generate images from ``input``.

        Args:
            input: tensor accepted by ``InitLayer`` (presumably a batch of
                latent codes; its shape is not checked here).

        Returns:
            ``(image_1024, image_128)``: the outputs of the 3x3 head on the
            last trunk stage and of the 1x1 head on stage 128. Both come out
            of ``nn.Tanh``, so their values lie in (-1, 1).
        """
        feat_4 = self._init(input)
        feat_8 = self._upsample_8(feat_4)
        feat_16 = self._upsample_16(feat_8)
        feat_32 = self._upsample_32(feat_16)

        # Each SLEBlock receives the earlier feature map first and the freshly
        # upsampled one second.
        feat_64 = self._sle_64(feat_4, self._upsample_64(feat_32))
        feat_128 = self._sle_128(feat_8, self._upsample_128(feat_64))
        feat_256 = self._sle_256(feat_16, self._upsample_256(feat_128))
        feat_512 = self._sle_512(feat_32, self._upsample_512(feat_256))
        feat_1024 = self._upsample_1024(feat_512)

        image_128 = self._out_128(feat_128)
        image_1024 = self._out_1024(feat_1024)
        return image_1024, image_128
class Discriminrator(nn.Module, HugGANModelHubMixin):
    """Two-track discriminator with auxiliary reconstruction decoders.

    NOTE: the class name is misspelled ("Discriminrator"). It is kept as-is
    because callers and saved hub configs may refer to it by this name.

    Large track (``images_1024``):
      * a strided two-convolution stem;
      * five ``DownsampleBlockT2`` stages;
      * three ``SLEBlock`` gates fed by earlier stages.

    Small track (``images_128``):
      * one strided convolution;
      * three ``DownsampleBlockT1`` stages.

    Each track ends in a convolutional logit head producing one channel. For
    non-fake inputs, three ``Decoder`` heads also reconstruct images:
      * one from the final large-track features;
      * one from the small-track features;
      * one from a crop of an intermediate large-track feature map.

    Args:
        in_channels: channel count of the input images. The decoders emit the
            same number of channels so their outputs are comparable with the
            inputs (previously this was hard-coded to 3).
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # Channel width per stage, keyed by the stage label used in the
        # attribute names (``_downsample_64`` -> key 64, ...).
        self._channels = {
            4: 1024,
            8: 512,
            16: 256,
            32: 128,
            64: 128,
            128: 64,
            256: 32,
            512: 16,
            1024: 8,
        }

        # Large-track stem: two stride-2 convolutions. BatchNorm is applied
        # only after the second one.
        self._init = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=self._channels[1024],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=self._channels[1024],
                out_channels=self._channels[512],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self._channels[512]),
            nn.LeakyReLU(negative_slope=0.2),
        )

        self._downsample_256 = DownsampleBlockT2(in_channels=self._channels[512], out_channels=self._channels[256])
        self._downsample_128 = DownsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[128])
        self._downsample_64 = DownsampleBlockT2(in_channels=self._channels[128], out_channels=self._channels[64])
        self._downsample_32 = DownsampleBlockT2(in_channels=self._channels[64], out_channels=self._channels[32])
        self._downsample_16 = DownsampleBlockT2(in_channels=self._channels[32], out_channels=self._channels[16])

        # SLE gates: in_channels matches the earlier stage that drives the
        # gate, out_channels the stage being modulated (see forward()).
        self._sle_64 = SLEBlock(in_channels=self._channels[512], out_channels=self._channels[64])
        self._sle_32 = SLEBlock(in_channels=self._channels[256], out_channels=self._channels[32])
        self._sle_16 = SLEBlock(in_channels=self._channels[128], out_channels=self._channels[16])

        self._small_track = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=self._channels[256],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            DownsampleBlockT1(in_channels=self._channels[256], out_channels=self._channels[128]),
            DownsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[64]),
            DownsampleBlockT1(in_channels=self._channels[64], out_channels=self._channels[32]),
        )

        # Logit heads: each ends in an unpadded 4x4 convolution to 1 channel.
        self._features_large = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[16],
                out_channels=self._channels[8],
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self._channels[8]),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=self._channels[8],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            )
        )
        self._features_small = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[32],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
        )

        # Reconstruction decoders emit images with as many channels as the
        # discriminator's inputs (identical to the old value 3 for RGB).
        self._decoder_large = Decoder(in_channels=self._channels[16], out_channels=in_channels)
        self._decoder_small = Decoder(in_channels=self._channels[32], out_channels=in_channels)
        self._decoder_piece = Decoder(in_channels=self._channels[32], out_channels=in_channels)

    def forward(self, images_1024: torch.Tensor,
                images_128: torch.Tensor,
                image_type: ImageType) -> \
            Union[
                torch.Tensor,
                Tuple[torch.Tensor, Tuple[Any, Any, Any]]
            ]:
        """Score a batch of images and, for non-fake inputs, reconstruct them.

        Args:
            images_1024: batch for the large track (presumably the full-size
                images; the shape is not checked here).
            images_128: batch for the small track (presumably a downscaled
                copy of the same images).
            image_type: ``ImageType.FAKE`` skips the decoders. Any other value
                enables them and is also passed to ``crop_image_part`` to
                choose the cropped region of ``down_32``.

        Returns:
            A 1-D tensor ``features`` when ``image_type`` is
            ``ImageType.FAKE``. It holds the flattened large-track logits
            followed by the flattened small-track logits (all batch items
            concatenated along dim 0). Otherwise, returns
            ``(features, (dec_large, dec_small, dec_piece))``.
        """
        # large track
        down_512 = self._init(images_1024)
        down_256 = self._downsample_256(down_512)
        down_128 = self._downsample_128(down_256)
        down_64 = self._downsample_64(down_128)
        down_64 = self._sle_64(down_512, down_64)
        down_32 = self._downsample_32(down_64)
        down_32 = self._sle_32(down_256, down_32)
        down_16 = self._downsample_16(down_32)
        down_16 = self._sle_16(down_128, down_16)

        # small track
        down_small = self._small_track(images_128)

        # features: reshape (not view) so flattening never fails on a
        # non-contiguous head output; values are identical.
        features_large = self._features_large(down_16).reshape(-1)
        features_small = self._features_small(down_small).reshape(-1)
        features = torch.cat([features_large, features_small], dim=0)

        # decoder
        if image_type != ImageType.FAKE:
            dec_large = self._decoder_large(down_16)
            dec_small = self._decoder_small(down_small)
            dec_piece = self._decoder_piece(crop_image_part(down_32, image_type))
            return features, (dec_large, dec_small, dec_piece)

        return features