# pix2pix-sar2rgb / networks.py
# Uploaded by yuulind -- "Initial commit" (a664a45), 10.1 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import DownsamplingBlock, UpsamplingBlock
class UnetEncoder(nn.Module):
    """U-Net encoder made of eight chained downsampling stages.

    Stage widths follow the pix2pix notation
    C64-C128-C256-C512-C512-C512-C512-C512, where Ck is a block with k
    filters. The layers inside each block are whatever
    ``DownsamplingBlock`` (imported from ``.layers``) builds.
    """

    def __init__(self, c_in=3, c_out=512):
        """Build the stages ``enc1`` ... ``enc8``.

        Args:
            c_in (int, optional): Number of input channels. Default is 3.
            c_out (int, optional): Number of channels produced by the last
                stage. Default is 512.
        """
        super().__init__()
        # Consecutive pairs of this tuple are the (in, out) widths of a stage.
        widths = (c_in, 64, 128, 256, 512, 512, 512, 512, c_out)
        for stage, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            if stage == 1:
                # Only the first stage (C64) is created with use_norm=False.
                block = DownsamplingBlock(n_in, n_out, use_norm=False)
            else:
                block = DownsamplingBlock(n_in, n_out)
            # Keep the attribute names enc1..enc8 so that state_dict keys
            # and registration order stay the same.
            setattr(self, f"enc{stage}", block)

    def forward(self, x):
        """Run ``x`` through every stage and collect all activations.

        Args:
            x: Input passed to the first stage.

        Returns:
            list: Outputs of the eight stages ordered deepest first, i.e.
            ``[enc8_out, enc7_out, ..., enc1_out]``.
        """
        stages = (
            self.enc1, self.enc2, self.enc3, self.enc4,
            self.enc5, self.enc6, self.enc7, self.enc8,
        )
        activations = []
        current = x
        for stage in stages:
            current = stage(current)
            activations.append(current)
        # Callers expect the latest activation as the first element.
        activations.reverse()
        return activations
class UnetDecoder(nn.Module):
    """U-Net decoder: eight upsampling stages fed by encoder skip connections."""

    def __init__(self, c_in=512, c_out=64, use_upsampling=False, mode='nearest'):
        """Build the stages ``dec1`` ... ``dec8``.

        In pix2pix notation (Ck: block with k filters, CDk: the same block
        with 50% dropout) the layout is
        CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128. The layers inside
        each block are whatever ``UpsamplingBlock`` (from ``.layers``) builds.

        Args:
            c_in (int, optional): Number of input channels of the first
                stage. Default is 512.
            c_out (int, optional): Number of output channels of the last
                stage. Default is 64.
            use_upsampling (bool, optional): Upsampling method for decoder.
                If True, use an upsampling layer followed by a regular
                convolution. If False, use transpose convolution.
                Default is False.
            mode (str, optional): The upsampling algorithm: one of
                'nearest', 'bilinear', 'bicubic'. Default: 'nearest'.
        """
        super().__init__()
        common = {'use_upsampling': use_upsampling, 'mode': mode}
        # (input channels, output channels, created with use_dropout=True?)
        # From dec2 on, the input width is the previous stage's output plus
        # the skip tensor concatenated in forward().
        plan = (
            (c_in, 512, True),     # CD512
            (1024, 512, True),     # CD1024
            (1024, 512, True),     # CD1024
            (1024, 512, False),    # C1024
            (1024, 256, False),    # C1024
            (512, 128, False),     # C512
            (256, 64, False),      # C256
            (128, c_out, False),   # C128
        )
        for stage, (n_in, n_out, with_dropout) in enumerate(plan, start=1):
            # Stages without dropout do not pass use_dropout at all, leaving
            # the choice to UpsamplingBlock's own default.
            options = {'use_dropout': True, **common} if with_dropout else common
            # Attribute names dec1..dec8 keep state_dict keys unchanged.
            setattr(self, f"dec{stage}", UpsamplingBlock(n_in, n_out, **options))

    def forward(self, x):
        """Decode the encoder activations.

        Args:
            x: Indexable collection of eight encoder activations, deepest
                first (``x[0]`` is the bottleneck, ``x[7]`` the shallowest).

        Returns:
            Output of ``dec8``.
        """
        stages = (
            self.dec1, self.dec2, self.dec3, self.dec4,
            self.dec5, self.dec6, self.dec7, self.dec8,
        )
        merged = stages[0](x[0])
        for level in range(1, 8):
            # Put the skip activation in front of the upsampled features
            # along dim 1, then feed the result to the next stage.
            merged = stages[level](torch.cat([x[level], merged], 1))
        return merged
class UnetGenerator(nn.Module):
    """U-Net based generator: encoder, decoder and an output head."""

    def __init__(self, c_in=3, c_out=3, use_upsampling=False, mode='nearest'):
        """Assemble the generator.

        Args:
            c_in (int, optional): The number of input channels. Default is 3.
            c_out (int, optional): The number of output channels.
                Default is 3.
            use_upsampling (bool, optional): Upsampling method for decoder.
                If True, use an upsampling layer followed by a regular
                convolution. If False, use transpose convolution.
                Default is False.
            mode (str, optional): The upsampling algorithm: one of
                'nearest', 'bilinear', 'bicubic'. Default: 'nearest'.
        """
        super().__init__()
        self.encoder = UnetEncoder(c_in=c_in)
        self.decoder = UnetDecoder(use_upsampling=use_upsampling, mode=mode)

        # The pix2pix paper states that, after the last decoder layer, a
        # convolution maps to the number of output channels (3 in general,
        # 2 for colorization) and is followed by a Tanh. The official Lua
        # implementation applies only the Tanh. Here a 3x3 convolution is
        # added before the Tanh.
        # See https://github.com/phillipi/pix2pix and
        # https://arxiv.org/abs/1611.07004
        #
        # The 64 input channels match the decoder's default c_out.
        to_output = nn.Conv2d(
            64, c_out, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.head = nn.Sequential(to_output, nn.Tanh())

    def forward(self, x):
        """Return ``head(decoder(encoder(x)))``."""
        return self.head(self.decoder(self.encoder(x)))
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator that outputs one logit per patch."""

    def __init__(self, c_in=3, c_hid=64, n_layers=3):
        """Stack the discriminator layers into ``self.model``.

        Args:
            c_in (int, optional): The number of input channels. Defaults to 3.
            c_hid (int, optional): The number of channels after the first
                conv layer. Defaults to 64.
            n_layers (int, optional): The number of convolution blocks in
                the discriminator. Defaults to 3.
        """
        super().__init__()

        # First block: c_in -> c_hid, created with use_norm=False.
        layers = [DownsamplingBlock(c_in, c_hid, use_norm=False)]
        width = c_hid

        # The last block uses stride 1 and is added separately below, so
        # this loop only creates (n_layers - 1) blocks.
        for depth in range(1, n_layers):
            # Double the width per block, capped at 8 * c_hid
            # (512 channels for the default c_hid of 64).
            next_width = c_hid * min(2 ** depth, 8)
            layers.append(DownsamplingBlock(width, next_width))
            width = next_width

        # Last block: same width rule, but with a stride of 1.
        final_width = c_hid * min(2 ** n_layers, 8)
        layers.append(DownsamplingBlock(width, final_width, stride=1))

        # Final convolution producing a single output channel.
        layers.append(
            nn.Conv2d(
                final_width, 1, kernel_size=4, stride=1, padding=1, bias=True
            )
        )

        # No sigmoid is appended on purpose: nn.BCEWithLogitsLoss combines
        # the sigmoid with the BCE loss for better numerical stability, so
        # the discriminator returns raw logits.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits computed by ``self.model`` for ``x``."""
        logits = self.model(x)
        return logits
class PixelDiscriminator(nn.Module):
    """PixelGAN discriminator (a 1x1 PatchGAN discriminator)."""

    def __init__(self, c_in=3, c_hid=64):
        """Build the PixelGAN discriminator, a special form of PatchGAN.

        All convolutions are created with 1x1 spatial filters.

        Args:
            c_in (int, optional): The number of input channels. Defaults to 3.
            c_hid (int, optional): The number of channels after the first
                conv layer. Defaults to 64.
        """
        super().__init__()
        # Settings shared by both blocks: 1x1 kernel, stride 1, no padding.
        pointwise = {'kernel_size': 1, 'stride': 1, 'padding': 0}
        stem = DownsamplingBlock(c_in, c_hid, **pointwise, use_norm=False)
        widen = DownsamplingBlock(c_hid, c_hid * 2, **pointwise)
        to_logit = nn.Conv2d(c_hid * 2, 1, kernel_size=1)
        # As in PatchDiscriminator, no sigmoid is appended:
        # nn.BCEWithLogitsLoss combines the sigmoid with the BCE loss for
        # better numerical stability, so raw logits are returned.
        self.model = nn.Sequential(stem, widen, to_logit)

    def forward(self, x):
        """Return the logits computed by ``self.model`` for ``x``."""
        logits = self.model(x)
        return logits
class PatchGAN(nn.Module):
    """Discriminator wrapper that selects a PixelGAN or a PatchGAN model."""

    def __init__(self, c_in=3, c_hid=64, mode='patch', n_layers=3):
        """Create the wrapped discriminator in ``self.model``.

        Args:
            c_in (int, optional): The number of input channels. Defaults to 3.
            c_hid (int, optional): The number of channels after the first
                convolutional layer. Defaults to 64.
            mode (str, optional): PatchGAN type. Use 'pixel' for PixelGAN;
                any other value selects the patch discriminator.
                Defaults to 'patch'.
            n_layers (int, optional): PatchGAN number of layers; not used
                when mode is 'pixel'. Defaults to 3.
                    - 16x16 PatchGAN if n=1
                    - 34x34 PatchGAN if n=2
                    - 70x70 PatchGAN if n=3
                    - 142x142 PatchGAN if n=4
                    - 286x286 PatchGAN if n=5
                    - 574x574 PatchGAN if n=6
        """
        super().__init__()
        wants_pixel = mode == 'pixel'
        self.model = (
            PixelDiscriminator(c_in, c_hid)
            if wants_pixel
            else PatchDiscriminator(c_in, c_hid, n_layers)
        )

    def forward(self, x):
        """Delegate to the wrapped discriminator and return its output."""
        result = self.model(x)
        return result