from torch import nn

## LAYER UTILITIES ####################################################################

# Normalization layers selectable through the ``norm_mode`` argument of
# ``conv`` / ``deconv`` / ``ResidualBlock``.
_NORM_LAYERS = {
    'batch': nn.BatchNorm2d,
    'instance': nn.InstanceNorm2d,
}


def _norm_layer(num_features: int, norm_mode: str) -> nn.Module:
    """Return a new 2-D normalization layer over ``num_features`` channels.

    Raises:
        ValueError: if ``norm_mode`` is not ``'batch'`` or ``'instance'``.
    """
    try:
        norm_cls = _NORM_LAYERS[norm_mode]
    except KeyError:
        # Previously an unknown mode silently produced no normalization at all,
        # which hides typos such as 'Batch'; fail loudly instead.
        raise ValueError(
            f"norm_mode must be one of {sorted(_NORM_LAYERS)}, got {norm_mode!r}"
        ) from None
    return norm_cls(num_features)


def _with_optional_norm(layer: nn.Module, out_channels: int, norm: bool,
                        norm_mode: str) -> nn.Sequential:
    """Wrap ``layer`` in an ``nn.Sequential``, appending a norm layer if requested."""
    layers = [layer]
    if norm:
        layers.append(_norm_layer(out_channels, norm_mode))
    return nn.Sequential(*layers)


def deconv(in_channels: int, out_channels: int, kernel_size: int, stride: int = 2,
           padding: int = 1, norm: bool = True,
           norm_mode: str = 'batch') -> nn.Sequential:
    """Create a transposed-convolution layer with optional batch/instance normalization.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Size of the transposed-convolution kernel.
        stride: Stride of the transposed convolution.
        padding: Implicit padding passed to ``nn.ConvTranspose2d``.
        norm: If true, append a normalization layer after the transposed convolution.
        norm_mode: ``'batch'`` (``nn.BatchNorm2d``) or ``'instance'``
            (``nn.InstanceNorm2d``). Only consulted when ``norm`` is true.

    Returns:
        ``nn.Sequential`` holding the ``nn.ConvTranspose2d`` and, if ``norm``,
        the normalization layer (in that order).

    Raises:
        ValueError: if ``norm`` is true and ``norm_mode`` is not recognized.
    """
    # bias=False: a following batch/instance norm subtracts the per-channel
    # mean, making a conv bias redundant. NOTE: the bias is also disabled when
    # norm=False; kept as-is so existing parameter sets/checkpoints still match.
    layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride,
                               padding, bias=False)
    return _with_optional_norm(layer, out_channels, norm, norm_mode)


def conv(in_channels: int, out_channels: int, kernel_size: int, stride: int = 2,
         padding: int = 1, norm: bool = True,
         norm_mode: str = 'batch') -> nn.Sequential:
    """Create a convolutional layer with optional batch/instance normalization.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Size of the convolution kernel.
        stride: Stride of the convolution.
        padding: Zero-padding passed to ``nn.Conv2d``.
        norm: If true, append a normalization layer after the convolution.
        norm_mode: ``'batch'`` (``nn.BatchNorm2d``) or ``'instance'``
            (``nn.InstanceNorm2d``). Only consulted when ``norm`` is true.

    Returns:
        ``nn.Sequential`` holding the ``nn.Conv2d`` and, if ``norm``, the
        normalization layer (in that order).

    Raises:
        ValueError: if ``norm`` is true and ``norm_mode`` is not recognized.
    """
    # bias=False: see the note in ``deconv`` — redundant before a norm layer,
    # and kept unconditional for compatibility with existing checkpoints.
    layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                      bias=False)
    return _with_optional_norm(layer, out_channels, norm, norm_mode)


class ResidualBlock(nn.Module):
    """Instantiate a residual block computing ``x + conv(x)`` with kernel_size = 3.

    The inner convolution uses stride 1 and padding 1 and keeps the channel
    count at ``conv_dim``, so its output has the same shape as its input and
    the skip connection can add the two element-wise.

    NOTE(review): no nonlinearity is applied inside the block; presumably
    intentional — confirm against the model that uses it.
    """

    def __init__(self, conv_dim: int, norm: bool = True, norm_mode: str = 'batch'):
        """
        Args:
            conv_dim: Number of input (and output) channels.
            norm: Forwarded to ``conv``; whether to normalize after the convolution.
            norm_mode: Forwarded to ``conv``; ``'batch'`` or ``'instance'``.
        """
        super().__init__()
        self._conv = conv(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1,
                          norm=norm, norm_mode=norm_mode)

    def forward(self, x):
        """Return ``x + conv(x)``."""
        return x + self._conv(x)