# VGG16-based perceptual (content) loss for image generation training.
import torch
import torch.nn as nn
import torchvision
class VGG16(nn.Module):
    """Frozen front slice of a pretrained VGG-16 used as a feature extractor.

    Wraps the first three encoder stages of torchvision's VGG-16
    (``features[:5]``, ``features[5:10]``, ``features[10:17]``) and freezes
    all of their parameters, so the network serves as a fixed loss backbone
    and is never updated during training.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favor of `weights=...`; kept for compatibility with the
        # torchvision version this file targets.
        vgg16 = torchvision.models.vgg16(pretrained=True)
        # Three encoder stages, split at the conventional VGG block borders.
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])
        # Freeze every stage: the perceptual backbone must stay fixed.
        for i in range(3):
            for param in getattr(self, f'enc_{i + 1:d}').parameters():
                param.requires_grad = False

    def forward(self, image):
        """Return the activations of the three encoder stages for ``image``.

        Args:
            image: input image batch fed to the first VGG stage.

        Returns:
            List of three tensors: the output of ``enc_1``, ``enc_2`` and
            ``enc_3``, each computed from the previous stage's output.
        """
        results = [image]
        for i in range(3):
            func = getattr(self, f'enc_{i + 1:d}')
            results.append(func(results[-1]))
        # Drop the raw input; keep only the three feature maps.
        return results[1:]
class ContentPerceptualLoss(nn.Module):
    """Perceptual (content) loss: MSE between VGG-16 feature maps."""

    def __init__(self):
        super().__init__()
        # Fixed, pretrained feature extractor (parameters frozen inside VGG16).
        self.VGG = VGG16()

    def calculate_loss(self, generated_images, target_images, device):
        """Average per-stage MSE between generated and target VGG features.

        Args:
            generated_images: batch of generated images (same shape as
                ``target_images``).
            target_images: batch of reference images.
            device: device the VGG extractor is moved to before the forward
                passes.

        Returns:
            Scalar tensor: the mean over the three encoder stages of the
            element-wise MSE between the corresponding feature maps.
        """
        self.VGG = self.VGG.to(device)
        generated_features = self.VGG(generated_images)
        target_features = self.VGG(target_images)
        # One MSE per encoder stage, averaged over the stages.
        stage_losses = [
            torch.mean((tgt - gen) ** 2)
            for tgt, gen in zip(target_features, generated_features)
        ]
        return sum(stage_losses) / len(stage_losses)