Spaces: Runtime error
| import torch | |
class TransformerNet(torch.nn.Module):
    """Image-to-image style-transfer network.

    Pipeline:
      * encoder: three reflection-padded convolutions (3 -> 32 -> 64 -> 128
        channels); the 2nd and 3rd use stride 2;
      * trunk: five 128-channel residual blocks;
      * decoder: two nearest-neighbour x2 upsample + conv stages
        (128 -> 64 -> 32 channels) and a final 9x9 conv back to 3 channels.

    Every convolution except the last is followed by an affine
    InstanceNorm2d and a ReLU; the final convolution's output is returned
    without any activation. For inputs whose height and width are divisible
    by 4, the output has the same spatial size as the input.

    The submodule attribute names (conv1, in1, ..., res1..res5, deconv1..3)
    become the state_dict keys, so renaming them would break checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Encoder.
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
        # Residual trunk (channel count and spatial size unchanged).
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        # Decoder: upsample-then-convolve instead of transposed convolutions.
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
        # Shared (parameter-free) activation.
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        features = X
        # conv -> instance norm -> ReLU for each encoder stage.
        for conv, norm in ((self.conv1, self.in1),
                           (self.conv2, self.in2),
                           (self.conv3, self.in3)):
            features = self.relu(norm(conv(features)))
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            features = block(features)
        # Same conv -> norm -> ReLU pattern for the upsampling stages.
        for conv, norm in ((self.deconv1, self.in4), (self.deconv2, self.in5)):
            features = self.relu(norm(conv(features)))
        # Final projection to 3 channels; intentionally left un-normalized.
        return self.deconv3(features)
class ConvLayer(torch.nn.Module):
    """Conv2d preceded by reflection padding.

    The input is padded by ``kernel_size // 2`` pixels on every side with
    ReflectionPad2d, so an odd kernel with stride 1 preserves the spatial
    size; larger strides downsample accordingly.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(pad)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))
class ResidualBlock(torch.nn.Module):
    """Residual block: two 3x3 conv + InstanceNorm stages plus an identity skip.

    introduced in: https://arxiv.org/abs/1512.03385
    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html

    Both convolutions keep the channel count and (stride 1, reflection
    padding of 1) the spatial size, so the input can be added back directly.
    A ReLU follows only the first normalization; nothing is applied after
    the addition.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.in1(self.conv1(x)))
        hidden = self.in2(self.conv2(hidden))
        # Identity shortcut.
        return hidden + x
class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded conv.

    Upsampling then convolving avoids the checkerboard artifacts that
    ConvTranspose2d tends to produce.
    ref: http://distill.pub/2016/deconv-checkerboard/

    ``upsample`` is passed as ``scale_factor`` to
    ``torch.nn.functional.interpolate``; when it is falsy (the default
    ``None``), the input goes straight to padding and convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        pad = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(pad)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        if self.upsample:
            x = torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
        return self.conv2d(self.reflection_pad(x))