Minuano's picture
Add applications files and generator saved models.
a847ff6
from torch import nn
## LAYER UTILITIES ####################################################################
def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, norm=True, norm_mode='batch'):
    """Create a transposed-convolutional layer with optional normalization.

    The transposed convolution is always built without a bias term
    (``bias=False``), whether or not a normalization layer follows it.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of channels produced by the transposed convolution.
        kernel_size: size of the convolving kernel.
        stride: stride of the transposed convolution (default 2).
        padding: padding passed to ``nn.ConvTranspose2d`` (default 1).
        norm: if True, append a normalization layer after the convolution.
        norm_mode: ``'batch'`` for ``nn.BatchNorm2d`` or ``'instance'`` for
            ``nn.InstanceNorm2d``; only consulted when ``norm`` is True.

    Returns:
        nn.Sequential holding the transposed convolution at index 0 and, when
        ``norm`` is True, the normalization layer at index 1.

    Raises:
        ValueError: if ``norm`` is True and ``norm_mode`` is neither
            ``'batch'`` nor ``'instance'``.
    """
    norm_layers = {
        'batch': nn.BatchNorm2d,
        'instance': nn.InstanceNorm2d,
    }
    # Conv first, norm second: keeps state_dict keys ('0.*', '1.*') unchanged.
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    ]
    if norm:
        # An unrecognized mode (e.g. a typo such as 'Batch') must not silently
        # yield an un-normalized layer.
        try:
            norm_cls = norm_layers[norm_mode]
        except KeyError:
            raise ValueError(
                f"norm_mode must be 'batch' or 'instance', got {norm_mode!r}"
            ) from None
        layers.append(norm_cls(out_channels))
    return nn.Sequential(*layers)
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, norm=True, norm_mode='batch'):
    """Create a convolutional layer with optional normalization.

    The convolution is always built without a bias term (``bias=False``),
    whether or not a normalization layer follows it.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: size of the convolving kernel.
        stride: stride of the convolution (default 2).
        padding: padding passed to ``nn.Conv2d`` (default 1).
        norm: if True, append a normalization layer after the convolution.
        norm_mode: ``'batch'`` for ``nn.BatchNorm2d`` or ``'instance'`` for
            ``nn.InstanceNorm2d``; only consulted when ``norm`` is True.

    Returns:
        nn.Sequential holding the convolution at index 0 and, when ``norm``
        is True, the normalization layer at index 1.

    Raises:
        ValueError: if ``norm`` is True and ``norm_mode`` is neither
            ``'batch'`` nor ``'instance'``.
    """
    norm_layers = {
        'batch': nn.BatchNorm2d,
        'instance': nn.InstanceNorm2d,
    }
    # Conv first, norm second: keeps state_dict keys ('0.*', '1.*') unchanged.
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    ]
    if norm:
        # An unrecognized mode (e.g. a typo such as 'Batch') must not silently
        # yield an un-normalized layer.
        try:
            norm_cls = norm_layers[norm_mode]
        except KeyError:
            raise ValueError(
                f"norm_mode must be 'batch' or 'instance', got {norm_mode!r}"
            ) from None
        layers.append(norm_cls(out_channels))
    return nn.Sequential(*layers)
class ResidualBlock(nn.Module):
    """Residual block that adds a 3x3 convolution branch back onto its input.

    The branch is built by ``conv`` with ``kernel_size=3, stride=1,
    padding=1`` and ``conv_dim`` channels in and out, so its output has the
    same shape as the input and the two can be summed directly. The block's
    normalization is whatever ``conv`` applies by default.

    NOTE: no activation is applied anywhere in the block; the output is
    simply ``x + branch(x)``.

    Args:
        conv_dim: number of channels of the input, which is also the number
            of channels of the output.
    """

    def __init__(self, conv_dim):
        super().__init__()
        # The attribute name determines the state_dict keys ('_conv.*'),
        # so it must stay '_conv' for saved weights to load.
        self._conv = conv(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        branch = self._conv(x)
        return x + branch