|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
""" |
|
VGG implementation from [InterDigitalInc](https://github.com/InterDigitalInc/HRFAE/blob/master/nets.py) |
|
""" |
|
|
|
class VGG(nn.Module):
    """VGG-16-style network that exposes every intermediate activation.

    Thirteen 3x3 convolutions (padding 1) are arranged in five stages, each
    stage followed by a 2x2 / stride-2 pooling layer, then three fully
    connected layers ending in a 101-way output (``fc8_101``).

    ``fc6`` expects 25088 = 512 * 49 flattened features, so the last pooled
    map must hold 49 positions per channel (e.g. 7x7, which is what a
    224x224 input yields after five halvings).

    Args:
        pool: Pooling used after each stage, either ``'max'`` or ``'avg'``.

    Raises:
        ValueError: If ``pool`` is neither ``'max'`` nor ``'avg'``.
    """

    def __init__(self, pool='max'):
        super().__init__()

        # Validate before allocating any layer: an unknown value used to be
        # accepted silently and only failed later, inside forward(), with an
        # AttributeError on the missing ``pool1``.
        pool_types = {'max': nn.MaxPool2d, 'avg': nn.AvgPool2d}
        if pool not in pool_types:
            raise ValueError(
                "pool must be 'max' or 'avg', got {!r}".format(pool)
            )
        pool_cls = pool_types[pool]

        # NOTE: state_dict keys are derived from these attribute names, so
        # they must not be renamed.
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.fc6 = nn.Linear(25088, 4096, bias=True)
        self.fc7 = nn.Linear(4096, 4096, bias=True)
        self.fc8_101 = nn.Linear(4096, 101, bias=True)

        self.pool1 = pool_cls(kernel_size=2, stride=2)
        self.pool2 = pool_cls(kernel_size=2, stride=2)
        self.pool3 = pool_cls(kernel_size=2, stride=2)
        self.pool4 = pool_cls(kernel_size=2, stride=2)
        self.pool5 = pool_cls(kernel_size=2, stride=2)

    def forward(self, x):
        """Run the network and return all intermediate activations.

        Args:
            x: Batch of feature maps with 3 channels in dim 1 (required by
                ``conv1_1``).

        Returns:
            dict mapping layer tags to tensors: ``'rXY'`` is the ReLU output
            of ``convX_Y``; ``'pX'`` is the output of pooling stage X
            (``'p5'`` is stored flattened to ``(batch, features)``);
            ``'fc6'`` and ``'fc7'`` are the ReLU-activated hidden layers and
            ``'fc8'`` is the raw output of ``fc8_101``.
        """
        out = {}

        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])

        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])

        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['p3'] = self.pool3(out['r33'])

        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['p4'] = self.pool4(out['r43'])

        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))

        pooled = self.pool5(out['r53'])
        # reshape (not view): same result for contiguous tensors, but it also
        # works when the pooled map is non-contiguous (e.g. channels_last),
        # where view() raises.
        out['p5'] = pooled.reshape(pooled.size(0), -1)

        out['fc6'] = F.relu(self.fc6(out['p5']))
        out['fc7'] = F.relu(self.fc7(out['fc6']))
        out['fc8'] = self.fc8_101(out['fc7'])
        return out