"""
VGG implementation from [InterDigitalInc](https://github.com/InterDigitalInc/HRFAE/blob/master/nets.py)
"""
import torch.nn as nn
import torch.nn.functional as F


class VGG(nn.Module):
    """VGG-16-style network that exposes every intermediate activation.

    The layout is 13 3x3 convolutions in five stages, each stage followed by a
    2x2/stride-2 pooling layer, then three fully connected layers ending in a
    101-way output (``fc8_101``). NOTE(review): 101 outputs presumably means
    age classes 0-100 as in the referenced HRFAE code; confirm against the
    weights you load.

    Attribute names are kept exactly as in the original implementation so
    that pretrained ``state_dict`` files still load by key.

    Args:
        pool: ``'max'`` for max pooling or ``'avg'`` for average pooling
            between convolution stages.

    Raises:
        ValueError: if ``pool`` is neither ``'max'`` nor ``'avg'``. Without
            this check the pooling layers would never be created and
            ``forward`` would fail later with an ``AttributeError``.
    """

    def __init__(self, pool='max'):
        super().__init__()
        if pool == 'max':
            pool_cls = nn.MaxPool2d
        elif pool == 'avg':
            pool_cls = nn.AvgPool2d
        else:
            raise ValueError(f"pool must be 'max' or 'avg', got {pool!r}")

        # vgg modules
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        # 25088 = 512 channels * 7 * 7: the pool5 map must be 7x7 spatially.
        self.fc6 = nn.Linear(25088, 4096, bias=True)
        self.fc7 = nn.Linear(4096, 4096, bias=True)
        self.fc8_101 = nn.Linear(4096, 101, bias=True)

        # Pooling layers carry no parameters, so choosing the class here does
        # not affect which state_dict keys exist.
        self.pool1 = pool_cls(kernel_size=2, stride=2)
        self.pool2 = pool_cls(kernel_size=2, stride=2)
        self.pool3 = pool_cls(kernel_size=2, stride=2)
        self.pool4 = pool_cls(kernel_size=2, stride=2)
        self.pool5 = pool_cls(kernel_size=2, stride=2)

    def forward(self, x):
        """Run the network and return all intermediate activations.

        Args:
            x: image batch of shape ``(N, 3, H, W)``. After the five 2x
                poolings the spatial size must be 7x7 to match ``fc6``, so
                e.g. ``224x224`` inputs work. NOTE(review): the expected pixel
                range and normalization are not established here; confirm them
                against the pretrained weights.

        Returns:
            dict mapping layer names to tensors, in forward order:
            ``'r<stage><idx>'`` holds the ReLU output of conv ``<stage>_<idx>``,
            ``'p1'``..``'p4'`` hold the pooled feature maps,
            ``'p5'`` holds the pool5 output flattened to ``(N, 25088)``,
            ``'fc6'`` and ``'fc7'`` hold ReLU'd fully connected outputs, and
            ``'fc8'`` holds the raw (pre-activation) 101-way logits.
        """
        out = {}
        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['p3'] = self.pool3(out['r33'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['p4'] = self.pool4(out['r43'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        # flatten(1) matches the old view(N, -1) for contiguous maps, but also
        # handles non-contiguous tensors and empty batches.
        out['p5'] = self.pool5(out['r53']).flatten(1)
        out['fc6'] = F.relu(self.fc6(out['p5']))
        out['fc7'] = F.relu(self.fc7(out['fc6']))
        out['fc8'] = self.fc8_101(out['fc7'])
        return out