# Header captured from the hosting page when this file was scraped:
#   status shown: "Runtime error"
#   file size: 1,701 bytes; revision: a847ff6
from torch import nn
## LAYER UTILITIES ####################################################################
def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, norm=True, norm_mode='batch'):
    """Build a transposed convolution, optionally followed by a normalization layer.

    Args:
        in_channels: Number of input channels of the ``ConvTranspose2d``.
        out_channels: Number of output channels (also the feature count of the
            normalization layer, if one is added).
        kernel_size: Kernel size passed to ``ConvTranspose2d``.
        stride: Stride passed to ``ConvTranspose2d`` (default 2).
        padding: Padding passed to ``ConvTranspose2d`` (default 1).
        norm: When falsy, no normalization layer is appended.
        norm_mode: ``'instance'`` appends ``InstanceNorm2d``; ``'batch'`` appends
            ``BatchNorm2d``. Any other value appends nothing.

    Returns:
        ``nn.Sequential`` holding the transposed convolution and, optionally,
        one normalization layer.
    """
    # Bias is disabled unconditionally, even when no normalization layer follows.
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)

    if not norm:
        return nn.Sequential(upsample)
    if norm_mode == 'instance':
        return nn.Sequential(upsample, nn.InstanceNorm2d(out_channels))
    if norm_mode == 'batch':
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels))
    # Unrecognized norm_mode: fall back to the bare transposed convolution.
    return nn.Sequential(upsample)
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, norm=True, norm_mode='batch'):
    """Build a 2-D convolution, optionally followed by a normalization layer.

    Args:
        in_channels: Number of input channels of the ``Conv2d``.
        out_channels: Number of output channels (also the feature count of the
            normalization layer, if one is added).
        kernel_size: Kernel size passed to ``Conv2d``.
        stride: Stride passed to ``Conv2d`` (default 2).
        padding: Padding passed to ``Conv2d`` (default 1).
        norm: When falsy, no normalization layer is appended.
        norm_mode: ``'instance'`` appends ``InstanceNorm2d``; ``'batch'`` appends
            ``BatchNorm2d``. Any other value appends nothing.

    Returns:
        ``nn.Sequential`` holding the convolution and, optionally, one
        normalization layer.
    """
    # Bias is disabled unconditionally, even when no normalization layer follows.
    stack = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)]

    if norm:
        norm_layer = (
            nn.InstanceNorm2d(out_channels) if norm_mode == 'instance'
            else nn.BatchNorm2d(out_channels) if norm_mode == 'batch'
            else None  # unrecognized mode: add no normalization
        )
        if norm_layer is not None:
            stack.append(norm_layer)

    return nn.Sequential(*stack)
class ResidualBlock(nn.Module):
    """Residual block: ``x + conv(x)`` with a single 3x3 convolution branch.

    The branch is built with this module's ``conv`` helper using its default
    normalization settings (``norm=True``, ``norm_mode='batch'``), so it is a
    bias-free ``Conv2d`` followed by ``BatchNorm2d``. No activation is applied
    inside the block.

    Args:
        conv_dim: Number of channels of both the input and the output. Input and
            output channel counts must match for the skip-connection addition.
    """

    def __init__(self, conv_dim):
        super().__init__()
        # kernel_size=3 with stride=1 and padding=1 keeps the spatial size, so the
        # branch output can be added element-wise to the input.
        # The attribute name `_conv` is kept so existing state_dict keys still load.
        self._conv = conv(
            conv_dim, conv_dim, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Return ``x`` plus the convolutional branch applied to ``x``."""
        return x + self._conv(x)