import torch
import torch.nn as nn
import torch.nn.functional as F


# Dropout layer that stays active even in evaluation mode
# (dropout here serves as the generator's source of randomness).
class DropoutAlways(nn.Dropout2d):
    def forward(self, x):
        return F.dropout2d(x, self.p, training=True)


# Encoder block: strided 4x4 convolution that halves the spatial resolution.
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1,
                      padding_mode='reflect', bias=not normalize),
            # Note that nn.Identity() is just a placeholder layer that returns its input.
            nn.InstanceNorm2d(out_channels, affine=True) if normalize else nn.Identity(),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.block(x)


# Decoder block: transposed 4x4 convolution that doubles the spatial resolution.
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=False, activation='relu'):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=not normalize),
            nn.InstanceNorm2d(out_channels, affine=True) if normalize else nn.Identity(),
            DropoutAlways() if dropout else nn.Identity(),
            nn.ReLU() if activation == 'relu' else nn.Tanh(),
        )

    def forward(self, x):
        return self.block(x)


# U-Net generator: 8 encoder blocks down to a 1x1 bottleneck, then 8 decoder
# blocks back up, with skip connections concatenating each encoder output onto
# the decoder input of the same resolution. Maps a 1-channel image to a
# 2-channel output in (-1, 1) via the final tanh.
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder1 = DownBlock(1, 64, normalize=False)     # 256x256 -> 128x128
        self.encoder2 = DownBlock(64, 128)                    # 128x128 -> 64x64
        self.encoder3 = DownBlock(128, 256)                   # 64x64 -> 32x32
        self.encoder4 = DownBlock(256, 512)                   # 32x32 -> 16x16
        self.encoder5 = DownBlock(512, 512)                   # 16x16 -> 8x8
        self.encoder6 = DownBlock(512, 512)                   # 8x8 -> 4x4
        self.encoder7 = DownBlock(512, 512)                   # 4x4 -> 2x2
        self.encoder8 = DownBlock(512, 512, normalize=False)  # 2x2 -> 1x1
        # Decoder (input channels are doubled by the skip-connection concat)
        self.decoder1 = UpBlock(512, 512, dropout=True)       # 1x1 -> 2x2
        self.decoder2 = UpBlock(512 * 2, 512, dropout=True)   # 2x2 -> 4x4
        self.decoder3 = UpBlock(512 * 2, 512, dropout=True)   # 4x4 -> 8x8
        self.decoder4 = UpBlock(512 * 2, 512)                 # 8x8 -> 16x16
        self.decoder5 = UpBlock(512 * 2, 256)                 # 16x16 -> 32x32
        self.decoder6 = UpBlock(256 * 2, 128)                 # 32x32 -> 64x64
        self.decoder7 = UpBlock(128 * 2, 64)                  # 64x64 -> 128x128
        self.decoder8 = UpBlock(64 * 2, 2, normalize=False, activation='tanh')  # 128x128 -> 256x256

    def forward(self, x):
        # Encoder
        ch256_down = x
        ch128_down = self.encoder1(ch256_down)
        ch64_down = self.encoder2(ch128_down)
        ch32_down = self.encoder3(ch64_down)
        ch16_down = self.encoder4(ch32_down)
        ch8_down = self.encoder5(ch16_down)
        ch4_down = self.encoder6(ch8_down)
        ch2_down = self.encoder7(ch4_down)
        ch1 = self.encoder8(ch2_down)
        # Decoder: concatenate each upsampled map with the encoder map of the
        # same resolution along the channel dimension (U-Net skip connections).
        ch2_up = self.decoder1(ch1)
        ch2 = torch.cat([ch2_up, ch2_down], dim=1)
        ch4_up = self.decoder2(ch2)
        ch4 = torch.cat([ch4_up, ch4_down], dim=1)
        ch8_up = self.decoder3(ch4)
        ch8 = torch.cat([ch8_up, ch8_down], dim=1)
        ch16_up = self.decoder4(ch8)
        ch16 = torch.cat([ch16_up, ch16_down], dim=1)
        ch32_up = self.decoder5(ch16)
        ch32 = torch.cat([ch32_up, ch32_down], dim=1)
        ch64_up = self.decoder6(ch32)
        ch64 = torch.cat([ch64_up, ch64_down], dim=1)
        ch128_up = self.decoder7(ch64)
        ch128 = torch.cat([ch128_up, ch128_down], dim=1)
        ch256_up = self.decoder8(ch128)
        return ch256_up
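

# Quick shape check for the generator above. This is a minimal sketch: the
# batch size of 1 and the 256x256 input resolution are illustrative
# assumptions taken from the per-layer comments, not requirements enforced
# anywhere in the code.
if __name__ == '__main__':
    generator = Generator()
    generator.eval()  # DropoutAlways keeps dropout active even in eval mode
    dummy = torch.randn(1, 1, 256, 256)  # (batch, channels, height, width)
    with torch.no_grad():
        out = generator(dummy)
    # A 1-channel 256x256 input should come back as a 2-channel 256x256 map,
    # squashed into (-1, 1) by the final tanh.
    print(out.shape)  # torch.Size([1, 2, 256, 256])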