import torch import torch.nn as nn import torchvision class VGG16(nn.Module): def __init__(self): super(VGG16, self).__init__() vgg16 = torchvision.models.vgg16(pretrained=True) self.enc_1 = nn.Sequential(*vgg16.features[:5]) self.enc_2 = nn.Sequential(*vgg16.features[5:10]) self.enc_3 = nn.Sequential(*vgg16.features[10:17]) for i in range(3): for param in getattr(self, f'enc_{i+1:d}').parameters(): param.requires_grad = False def forward(self, image): results = [image] for i in range(3): func = getattr(self, f'enc_{i+1:d}') results.append(func(results[-1])) return results[1:] class ContentPerceptualLoss(nn.Module): def __init__(self): super().__init__() self.VGG = VGG16() def calculate_loss(self, generated_images, target_images, device): self.VGG = self.VGG.to(device) generated_features = self.VGG(generated_images) target_features = self.VGG(target_images) perceptual_loss = 0 perceptual_loss += torch.mean((target_features[0] - generated_features[0]) ** 2) perceptual_loss += torch.mean((target_features[1] - generated_features[1]) ** 2) perceptual_loss += torch.mean((target_features[2] - generated_features[2]) ** 2) perceptual_loss /= 3 return perceptual_loss