Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
from torchvision.models import VGG19_Weights, vgg19 | |
class VGG19:
    """
    Custom version of VGG19 with the maxpool layers replaced with avgpool as per the paper.

    Wraps torchvision's pretrained VGG19, runs an image through its convolutional
    ``features`` stack and records the activations at ``output_layers`` so they can be
    turned into Gram matrices. The "paper" is presumably Gatys et al. (neural style
    transfer), which swaps max pooling for average pooling — TODO confirm.
    """

    def __init__(self, freeze_weights):
        """
        Load the pretrained VGG19 onto the available device and swap its pooling layers.

        Args:
            freeze_weights: if truthy, ``requires_grad`` is turned off for every VGG
                parameter (pure feature extractor); otherwise it is turned on.
        """
        # Exposed so callers can move their input images onto the same device as the model.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # DEFAULT is already a VGG19_Weights member; no need to re-wrap it in the enum.
        self.model = vgg19(weights=VGG19_Weights.DEFAULT).to(self.device)
        # note: added one extra maxpool (layer 36) from the vgg... worked well so kept it in
        self.output_layers = [0, 4, 9, 18, 27, 36]  # vgg19 layers [convlayer1, maxpool, ..., maxpool]
        # Convert every max pool (indices 4, 9, 18, 27, 36 in VGG19) to an avg pool of the
        # same size/stride. Select by type rather than by output_layers so that changing
        # which layers are captured does not change which layers are pooled.
        pool_indices = [
            index
            for index, layer in enumerate(self.model.features)
            if isinstance(layer, nn.MaxPool2d)
        ]
        for index in pool_indices:
            self.model.features[index] = nn.AvgPool2d(kernel_size=2, stride=2)
        self.feature_maps = []
        # Module.requires_grad_ applies the flag to every parameter of the model.
        self.model.requires_grad_(not freeze_weights)

    def __call__(self, x):
        """
        Take in an image, pass it through the VGG and capture the feature maps at each output layer.

        Args:
            x: image tensor fed to ``self.model.features``; presumably (C, H, W) or
                (B, C, H, W), normalized and already on the model's device — TODO confirm
                with callers.

        Returns:
            self, so calls can be chained (``vgg(img).get_gram_matrices()``). The captured
            activations are stored in ``self.feature_maps`` in layer order.
        """
        wanted = set(self.output_layers)
        last = max(self.output_layers)  # deepest layer needed; nothing after it is run
        self.feature_maps = []
        for index, layer in enumerate(self.model.features):
            x = layer(x)  # pass the img through the layer to get feature maps of the img
            if index in wanted:
                # NOTE(review): torchvision's VGG uses ReLU(inplace=True), so a map captured
                # right before a ReLU (e.g. layer 0) is later overwritten in place with its
                # rectified values — confirm this is intended.
                self.feature_maps.append(x)
            if index == last:
                # stop VGG execution as we've captured the feature maps from all the important layers
                break
        return self

    def get_gram_matrices(self):
        """
        Convert the feature maps captured by the call method into Gram matrices.

        A map of shape (C, H, W) is flattened to F of shape (C, H*W) and turned into
        F @ F^T / (4 * C * H * W), a (C, C) matrix of channel auto-correlations. Maps with
        leading batch dimensions, e.g. (B, C, H, W), yield one Gram matrix per sample,
        shape (B, C, C).

        Returns:
            list of tensors, one per captured feature map, in capture order.

        Raises:
            ValueError: if a feature map has fewer than 3 dimensions.
        """
        gram_matrices = []
        for fm in self.feature_maps:
            if fm.dim() < 3:
                raise ValueError(
                    f"expected feature maps shaped (..., C, H, W), got {tuple(fm.shape)}"
                )
            *batch, n, h, w = fm.shape  # num filters n and (filter dims h and w)
            # reshape each filterbank into a 2D mat before doing auto correlation
            flat = fm.reshape(*batch, n, h * w)
            gram = flat @ flat.transpose(-2, -1)
            gram_matrices.append(gram / (4.0 * n * h * w))  # normalize by layer output dims
        return gram_matrices