import torch
from torch import nn
from torchvision.models import VGG19_Weights, vgg19


class VGG19:
    """
    Custom version of VGG19 with the maxpool layers replaced with avgpool, as per the paper
    """

    def __init__(self, freeze_weights):
        """
        freeze_weights: if True, gradients for the VGG parameters are turned off
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = vgg19(weights=VGG19_Weights.DEFAULT).to(device)
        # note: added one extra maxpool (layer 36) from the vgg... worked well so kept it in
        self.output_layers = [0, 4, 9, 18, 27, 36]  # vgg19 layers [convlayer1, maxpool, ..., maxpool]
        for layer in self.output_layers[1:]:  # convert the maxpool layers to an avgpool
            self.model.features[layer] = nn.AvgPool2d(kernel_size=2, stride=2)
        self.feature_maps = []
        for param in self.model.parameters():
            param.requires_grad = not freeze_weights
    def __call__(self, x):
        """
        Take in an image, pass it through the VGG, and capture feature maps at each of the output layers
        """
        self.feature_maps = []
        for index, layer in enumerate(self.model.features):
            x = layer(x)  # pass the img through the layer to get feature maps of the img
            if index in self.output_layers:
                self.feature_maps.append(x)
            if index == self.output_layers[-1]:
                # stop VGG execution as we've captured the feature maps from all the important layers
                break
        return self
    def get_gram_matrices(self):
        """
        Convert the feature maps captured by the call method into gram matrices
        """
        gram_matrices = []
        for fm in self.feature_maps:
            # assumes an unbatched feature map of shape (n, x, y): num filters n and filter dims x and y
            n, x, y = fm.size()
            F = fm.reshape(n, x * y)  # reshape filterbank into a 2D mat before doing auto correlation
            gram_mat = (F @ F.t()) / (4. * n * x * y)  # auto corr + normalize by layer output dims
            gram_matrices.append(gram_mat)
        return gram_matrices
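

# Example usage (a minimal sketch, not part of the original file): build the frozen
# feature extractor, push an unbatched (3, H, W) image through it, and read off the
# gram matrices. The random tensors below are only placeholders for real, normalized
# style/generated images.
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vgg = VGG19(freeze_weights=True)

    style_image = torch.rand(3, 224, 224, device=device)  # placeholder style image
    style_grams = vgg(style_image).get_gram_matrices()
    for layer_idx, gram in zip(vgg.output_layers, style_grams):
        print(f"layer {layer_idx}: gram matrix shape {tuple(gram.shape)}")

    # A hypothetical style loss (assumption: MSE between gram matrices, in the spirit of
    # Gatys et al.); `generated` stands in for the image tensor being optimized.
    generated = torch.rand(3, 224, 224, device=device, requires_grad=True)
    generated_grams = vgg(generated).get_gram_matrices()
    style_loss = sum(nn.functional.mse_loss(g, s) for g, s in zip(generated_grams, style_grams))
    print(f"style loss: {style_loss.item():.6f}")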