import torch
import torch.nn as nn
from typing import Sequence


class DownBlock(nn.Module):
    """Downsampling stage: strided conv -> optional InstanceNorm -> LeakyReLU(0.2).

    The 4x4 / stride-2 / padding-1 convolution (reflect padding) halves even
    spatial sizes: an (N, in_filters, H, W) input becomes
    (N, out_filters, H/2, W/2).

    Args:
        in_filters: Number of input channels.
        out_filters: Number of output channels.
        normal: If True, an affine ``InstanceNorm2d`` follows the convolution
            and the convolution's bias is disabled.
    """

    def __init__(self, in_filters: int, out_filters: int, normal: bool = True) -> None:
        super().__init__()
        layers = [
            nn.Conv2d(
                in_filters,
                out_filters,
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode='reflect',
                # InstanceNorm subtracts the per-channel mean, which would cancel
                # a per-channel conv bias, so the bias is only kept without norm.
                bias=not normal,
            )
        ]
        if normal:
            layers.append(nn.InstanceNorm2d(out_filters, affine=True))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the downsampling stage to ``x``."""
        return self.block(x)


class UpBlock(nn.Module):
    """Upsampling stage followed by a channel-wise skip concatenation.

    Transposed conv (4x4, stride 2, padding 1, doubles H and W) ->
    affine InstanceNorm -> ReLU -> optional Dropout. The result is then
    concatenated with ``skip_input`` along dim 1. The output therefore has
    ``out_filters + skip_input.shape[1]`` channels.

    Args:
        in_filters: Channels of the tensor passed as ``x``.
        out_filters: Channels produced by the transposed convolution.
        dropout: Dropout probability; a falsy value (e.g. 0.0) adds no
            dropout layer.
    """

    def __init__(self, in_filters: int, out_filters: int, dropout: float = 0.0) -> None:
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_filters, out_filters, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_filters, affine=True),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        # NOTE: named ``model`` (unlike DownBlock's ``block``); kept as-is so
        # existing state_dict keys ("...model.N...") continue to load.
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and concatenate ``skip_input`` on the channel dim.

        ``skip_input`` must have the same batch and spatial size as the
        upsampled ``x`` for ``torch.cat`` to succeed.
        """
        x = self.model(x)
        return torch.cat((x, skip_input), 1)


class Generator(nn.Module):
    """U-Net-style encoder/decoder with channel-concatenated skip connections.

    Structure, for ``features = (f0, ..., fL)``:

    * encoder: one ``DownBlock`` per entry (the first has no normalisation);
      each halves H and W and its output is kept as a skip.
    * bottleneck: ``DownBlock(fL, fL, normal=False)``, halving once more.
    * decoder: one ``UpBlock`` per entry of ``reversed(features)``; stage
      ``i`` consumes the previous output and is concatenated with the
      ``i``-th deepest skip. The first three stages use dropout 0.5.
    * final: ``ConvTranspose2d(2 * f0, out_channels)`` followed by ``Tanh``,
      so outputs lie in [-1, 1].

    Because every skip must match its decoder stage spatially, H and W must
    be divisible by ``2 ** (len(features) + 1)`` (256 for the default).

    Args:
        input_channels: Channels of the input tensor.
        features: Encoder widths, outermost first. Any sequence of ints; the
            default is a tuple so the default value cannot be mutated.
        out_channels: Channels of the generated output. The default of 3
            preserves the previously hard-coded value.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(
        self,
        input_channels: int,
        features: Sequence[int] = (64, 128, 256, 512, 512, 512, 512),
        out_channels: int = 3,
    ) -> None:
        super().__init__()
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one entry")

        # Registration order (encoder, decoder, bottleneck, final) and
        # construction order (encoder, bottleneck, final, decoder) are kept
        # identical to the original so state_dict layout and seeded parameter
        # initialisation do not change.
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        in_ch = input_channels
        for idx, feature in enumerate(features):
            # The outermost encoder stage is not normalised.
            self.encoder.append(DownBlock(in_ch, feature, normal=idx != 0))
            in_ch = feature

        deepest = features[-1]
        # Previously hard-coded as 512, which only matched the default features.
        # No norm here — presumably because this output can be 1x1, where
        # InstanceNorm is undefined in training mode (review: confirm intent).
        self.bottleneck = DownBlock(deepest, deepest, normal=False)
        self.final = nn.Sequential(
            # Input is the last decoder output (f0) concatenated with the
            # outermost skip (f0); previously hard-coded as 128.
            nn.ConvTranspose2d(2 * features[0], out_channels, 4, 2, 1),
            nn.Tanh(),
        )

        in_ch = deepest
        for idx, feature in enumerate(reversed(features)):
            # After the first stage, each input is the previous stage output
            # concatenated with a skip of equal width, hence the doubling.
            stage_in = in_ch if idx == 0 else in_ch * 2
            stage_dropout = 0.5 if idx < 3 else 0.0
            self.decoder.append(UpBlock(stage_in, feature, dropout=stage_dropout))
            in_ch = feature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder, then the output head."""
        skips = []
        for layer in self.encoder:
            x = layer(x)
            skips.append(x)
        x = self.bottleneck(x)
        # The deepest skip pairs with the first decoder stage.
        for layer, skip in zip(self.decoder, reversed(skips)):
            x = layer(x, skip)
        return self.final(x)