"""
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import torch.nn as nn


def _reflect_pad():
    """Build a reflection padding layer adding one pixel on every side."""
    return nn.ReflectionPad2d((1, 1, 1, 1))


def _conv3x3(in_channels, out_channels):
    """Build an unpadded 3x3, stride-1 convolution."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0)


def _relu():
    """Build an in-place ReLU."""
    return nn.ReLU(inplace=True)


def _pool():
    """Build a 2x2, stride-2 max pooling layer that also returns indices."""
    return nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)


def _unpool():
    """Build the 2x2, stride-2 max unpooling layer paired with `_pool`."""
    return nn.MaxUnpool2d(kernel_size=2, stride=2)


class VGGEncoder(nn.Module):
    """VGG-style encoder truncated after relu{level}_1.

    Only the layers needed for the requested ``level`` are created, so the
    set of submodules (and therefore the ``state_dict`` keys) depends on
    ``level``.  Every 3x3 convolution is preceded by a one-pixel reflection
    pad, which keeps the spatial size unchanged; each max pooling halves it
    (e.g. 224 -> 112 -> 56 -> 28).
    """

    def __init__(self, level):
        """Create the layers up to and including relu{level}_1.

        Args:
            level: depth of the encoder. Values below 2 build only the first
                stage; values of 4 or more build all four stages.
        """
        super(VGGEncoder, self).__init__()
        self.level = level

        # Stage 1: 1x1 colour transform, then 3 -> 64 channels.
        self.conv0 = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
        self.pad1_1 = _reflect_pad()
        self.conv1_1 = _conv3x3(3, 64)
        self.relu1_1 = _relu()
        if level < 2:
            return

        # Stage 2: 64 -> 64, pool (halves resolution), 64 -> 128.
        self.pad1_2 = _reflect_pad()
        self.conv1_2 = _conv3x3(64, 64)
        self.relu1_2 = _relu()
        self.maxpool1 = _pool()
        self.pad2_1 = _reflect_pad()
        self.conv2_1 = _conv3x3(64, 128)
        self.relu2_1 = _relu()
        if level < 3:
            return

        # Stage 3: 128 -> 128, pool, 128 -> 256.
        self.pad2_2 = _reflect_pad()
        self.conv2_2 = _conv3x3(128, 128)
        self.relu2_2 = _relu()
        self.maxpool2 = _pool()
        self.pad3_1 = _reflect_pad()
        self.conv3_1 = _conv3x3(128, 256)
        self.relu3_1 = _relu()
        if level < 4:
            return

        # Stage 4: three 256 -> 256 convolutions, pool, 256 -> 512.
        self.pad3_2 = _reflect_pad()
        self.conv3_2 = _conv3x3(256, 256)
        self.relu3_2 = _relu()
        self.pad3_3 = _reflect_pad()
        self.conv3_3 = _conv3x3(256, 256)
        self.relu3_3 = _relu()
        self.pad3_4 = _reflect_pad()
        self.conv3_4 = _conv3x3(256, 256)
        self.relu3_4 = _relu()
        self.maxpool3 = _pool()
        self.pad4_1 = _reflect_pad()
        self.conv4_1 = _conv3x3(256, 512)
        self.relu4_1 = _relu()

    def _stage1(self, image):
        """Run conv0 .. relu1_1 and return the relu1_1 activation."""
        return self.relu1_1(self.conv1_1(self.pad1_1(self.conv0(image))))

    def _stage2(self, feat1):
        """Run pad1_2 .. relu2_1.

        Returns the relu2_1 activation together with the indices produced by
        maxpool1 and the size of the tensor that was pooled.
        """
        pre_pool = self.relu1_2(self.conv1_2(self.pad1_2(feat1)))
        pooled, indices = self.maxpool1(pre_pool)
        feat2 = self.relu2_1(self.conv2_1(self.pad2_1(pooled)))
        return feat2, indices, pre_pool.size()

    def _stage3(self, feat2):
        """Run pad2_2 .. relu3_1.

        Returns the relu3_1 activation together with the indices produced by
        maxpool2 and the size of the tensor that was pooled.
        """
        pre_pool = self.relu2_2(self.conv2_2(self.pad2_2(feat2)))
        pooled, indices = self.maxpool2(pre_pool)
        feat3 = self.relu3_1(self.conv3_1(self.pad3_1(pooled)))
        return feat3, indices, pre_pool.size()

    def _stage4(self, feat3):
        """Run pad3_2 .. relu4_1.

        Returns the relu4_1 activation together with the indices produced by
        maxpool3 and the size of the tensor that was pooled.
        """
        hidden = self.relu3_2(self.conv3_2(self.pad3_2(feat3)))
        hidden = self.relu3_3(self.conv3_3(self.pad3_3(hidden)))
        pre_pool = self.relu3_4(self.conv3_4(self.pad3_4(hidden)))
        pooled, indices = self.maxpool3(pre_pool)
        feat4 = self.relu4_1(self.conv4_1(self.pad4_1(pooled)))
        return feat4, indices, pre_pool.size()

    def forward(self, x):
        """Encode ``x`` and return what is needed to undo the poolings.

        Returns:
            * level < 2: the relu1_1 activation alone (a tensor, not a tuple);
            * level 2: ``(out, pool1_idx, pool1_size)``;
            * level 3: ``(out, pool1_idx, pool1_size, pool2_idx, pool2_size)``;
            * otherwise: the same with ``pool3_idx, pool3_size`` appended.

            ``poolN_idx`` are the max pooling indices and ``poolN_size`` the
            size of the tensor that went into the N-th pooling.
        """
        feat1 = self._stage1(x)
        if self.level < 2:
            return feat1

        feat2, pool1_idx, pool1_size = self._stage2(feat1)
        if self.level < 3:
            return feat2, pool1_idx, pool1_size

        feat3, pool2_idx, pool2_size = self._stage3(feat2)
        if self.level < 4:
            return feat3, pool1_idx, pool1_size, pool2_idx, pool2_size

        feat4, pool3_idx, pool3_size = self._stage4(feat3)
        return (feat4, pool1_idx, pool1_size, pool2_idx, pool2_size,
                pool3_idx, pool3_size)

    def forward_multiple(self, x):
        """Encode ``x`` and return the reluN_1 activations, deepest first.

        Returns:
            * level < 2: the relu1_1 activation alone (a tensor, not a tuple);
            * level 2: ``(relu2_1, relu1_1)``;
            * level 3: ``(relu3_1, relu2_1, relu1_1)``;
            * otherwise: ``(relu4_1, relu3_1, relu2_1, relu1_1)``.
        """
        feat1 = self._stage1(x)
        if self.level < 2:
            return feat1

        feat2 = self._stage2(feat1)[0]
        if self.level < 3:
            return feat2, feat1

        feat3 = self._stage3(feat2)[0]
        if self.level < 4:
            return feat3, feat2, feat1

        feat4 = self._stage4(feat3)[0]
        return feat4, feat3, feat2, feat1


class VGGDecoder(nn.Module):
    """Decoder mirroring `VGGEncoder`, starting from relu{level}_1 features.

    Layers are created from the deepest requested stage down to conv1_1;
    each max pooling of the encoder is undone by a max unpooling that uses
    the indices and sizes returned by the encoder's ``forward``.
    """

    def __init__(self, level):
        """Create the layers needed to decode features of depth ``level``.

        Args:
            level: depth of the features to decode. With ``level <= 0`` no
                layer is created at all.
        """
        super(VGGDecoder, self).__init__()
        self.level = level

        if level > 3:
            # 512 -> 256, unpool, then three 256 -> 256 convolutions.
            self.pad4_1 = _reflect_pad()
            self.conv4_1 = _conv3x3(512, 256)
            self.relu4_1 = _relu()
            self.unpool3 = _unpool()
            self.pad3_4 = _reflect_pad()
            self.conv3_4 = _conv3x3(256, 256)
            self.relu3_4 = _relu()
            self.pad3_3 = _reflect_pad()
            self.conv3_3 = _conv3x3(256, 256)
            self.relu3_3 = _relu()
            self.pad3_2 = _reflect_pad()
            self.conv3_2 = _conv3x3(256, 256)
            self.relu3_2 = _relu()

        if level > 2:
            # 256 -> 128, unpool, 128 -> 128.
            self.pad3_1 = _reflect_pad()
            self.conv3_1 = _conv3x3(256, 128)
            self.relu3_1 = _relu()
            self.unpool2 = _unpool()
            self.pad2_2 = _reflect_pad()
            self.conv2_2 = _conv3x3(128, 128)
            self.relu2_2 = _relu()

        if level > 1:
            # 128 -> 64, unpool, 64 -> 64.
            self.pad2_1 = _reflect_pad()
            self.conv2_1 = _conv3x3(128, 64)
            self.relu2_1 = _relu()
            self.unpool1 = _unpool()
            self.pad1_2 = _reflect_pad()
            self.conv1_2 = _conv3x3(64, 64)
            self.relu1_2 = _relu()

        if level > 0:
            # Final 64 -> 3 projection; no activation follows it.
            self.pad1_1 = _reflect_pad()
            self.conv1_1 = _conv3x3(64, 3)

    def _stage4(self, feat, indices, size):
        """Run pad4_1 .. relu3_2, undoing the encoder's third pooling."""
        feat = self.relu4_1(self.conv4_1(self.pad4_1(feat)))
        feat = self.unpool3(feat, indices, output_size=size)
        feat = self.relu3_4(self.conv3_4(self.pad3_4(feat)))
        feat = self.relu3_3(self.conv3_3(self.pad3_3(feat)))
        return self.relu3_2(self.conv3_2(self.pad3_2(feat)))

    def _stage3(self, feat, indices, size):
        """Run pad3_1 .. relu2_2, undoing the encoder's second pooling."""
        feat = self.relu3_1(self.conv3_1(self.pad3_1(feat)))
        feat = self.unpool2(feat, indices, output_size=size)
        return self.relu2_2(self.conv2_2(self.pad2_2(feat)))

    def _stage2(self, feat, indices, size):
        """Run pad2_1 .. relu1_2, undoing the encoder's first pooling."""
        feat = self.relu2_1(self.conv2_1(self.pad2_1(feat)))
        feat = self.unpool1(feat, indices, output_size=size)
        return self.relu1_2(self.conv1_2(self.pad1_2(feat)))

    def forward(self, x, pool1_idx=None, pool1_size=None, pool2_idx=None,
                pool2_size=None, pool3_idx=None, pool3_size=None):
        """Decode features ``x`` back to a 3-channel tensor.

        Args:
            x: features to decode.
            pool1_idx, pool1_size: indices / output size passed to unpool1
                (used when ``level > 1``).
            pool2_idx, pool2_size: same for unpool2 (used when ``level > 2``).
            pool3_idx, pool3_size: same for unpool3 (used when ``level > 3``).

        Returns:
            The output of conv1_1, or ``x`` itself when ``level <= 0``.
        """
        decoded = x
        if self.level > 3:
            decoded = self._stage4(decoded, pool3_idx, pool3_size)
        if self.level > 2:
            decoded = self._stage3(decoded, pool2_idx, pool2_size)
        if self.level > 1:
            decoded = self._stage2(decoded, pool1_idx, pool1_size)
        if self.level > 0:
            decoded = self.conv1_1(self.pad1_1(decoded))
        return decoded