Spaces:
Runtime error
Runtime error
from torch import nn | |
## LAYER UTILITIES #################################################################### | |
def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1,
           norm=True, norm_mode='batch'):
    """Create a transposed-convolution layer, optionally followed by normalization.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the transposed convolution.
        kernel_size: Kernel size passed to ``nn.ConvTranspose2d``.
        stride: Stride passed to ``nn.ConvTranspose2d`` (default 2).
        padding: Padding passed to ``nn.ConvTranspose2d`` (default 1).
        norm: If True, append a normalization layer after the convolution.
        norm_mode: Normalization to append when ``norm`` is True:
            ``'batch'`` -> ``nn.BatchNorm2d``,
            ``'instance'`` -> ``nn.InstanceNorm2d``.
            Ignored when ``norm`` is False.

    Returns:
        nn.Sequential: the transposed convolution, followed by the
        normalization layer when ``norm`` is True.

    Raises:
        ValueError: if ``norm`` is True and ``norm_mode`` is neither
            ``'batch'`` nor ``'instance'``.
    """
    # The convolution is always created without a bias term, including when
    # norm=False. Presumably this is because a following normalization layer
    # makes the bias redundant. TODO confirm this is intended for norm=False.
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
    ]
    if norm:
        if norm_mode == 'instance':
            layers.append(nn.InstanceNorm2d(out_channels))
        elif norm_mode == 'batch':
            layers.append(nn.BatchNorm2d(out_channels))
        else:
            # Fail loudly: an unrecognized mode (e.g. a typo) would otherwise
            # silently produce a layer with no normalization at all.
            raise ValueError(
                f"norm_mode must be 'batch' or 'instance', got {norm_mode!r}"
            )
    return nn.Sequential(*layers)
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1,
         norm=True, norm_mode='batch'):
    """Create a convolutional layer, optionally followed by normalization.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        stride: Stride passed to ``nn.Conv2d`` (default 2).
        padding: Padding passed to ``nn.Conv2d`` (default 1).
        norm: If True, append a normalization layer after the convolution.
        norm_mode: Normalization to append when ``norm`` is True:
            ``'batch'`` -> ``nn.BatchNorm2d``,
            ``'instance'`` -> ``nn.InstanceNorm2d``.
            Ignored when ``norm`` is False.

    Returns:
        nn.Sequential: the convolution, followed by the normalization layer
        when ``norm`` is True.

    Raises:
        ValueError: if ``norm`` is True and ``norm_mode`` is neither
            ``'batch'`` nor ``'instance'``.
    """
    # The convolution is always created without a bias term, including when
    # norm=False. Presumably this is because a following normalization layer
    # makes the bias redundant. TODO confirm this is intended for norm=False.
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
    ]
    if norm:
        if norm_mode == 'instance':
            layers.append(nn.InstanceNorm2d(out_channels))
        elif norm_mode == 'batch':
            layers.append(nn.BatchNorm2d(out_channels))
        else:
            # Fail loudly: an unrecognized mode (e.g. a typo) would otherwise
            # silently produce a layer with no normalization at all.
            raise ValueError(
                f"norm_mode must be 'batch' or 'instance', got {norm_mode!r}"
            )
    return nn.Sequential(*layers)
class ResidualBlock(nn.Module):
    """Residual block with a single 3x3 convolutional layer and an identity skip.

    The inner layer is built by ``conv`` with ``conv_dim`` input and output
    channels, kernel size 3, stride 1 and padding 1. Normalization is left at
    ``conv``'s defaults. The forward pass returns ``x + layer(x)``.

    Args:
        conv_dim: Channel count of both the input and the output.
    """

    def __init__(self, conv_dim):
        super().__init__()
        # Attribute name ``_conv`` is kept as-is: it determines state_dict keys.
        self._conv = conv(
            in_channels=conv_dim,
            out_channels=conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # Skip connection: add the layer's output back onto its input.
        residual = self._conv(x)
        return x + residual