|
import os |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torchvision |
|
|
|
|
|
class VGGPerceptualLoss(torch.nn.Module):
    """Feature-space L1 loss computed on four slices of a VGG16 network.

    The ``features`` stack of a torchvision VGG16 is cut into four
    consecutive slices (``[0:4]``, ``[4:9]``, ``[9:16]``, ``[16:23]``).
    Both inputs are passed through the slices in sequence and the
    ``l1_loss`` between their activations after every slice is summed.

    NOTE(review): ``torchvision.models.vgg16()`` is called without any
    weights argument, so no pretrained weights are requested here.
    Presumably they are loaded elsewhere (e.g. from a checkpoint) --
    confirm, otherwise the loss is computed on untrained features.
    """

    # (start, stop) indices into ``vgg16().features`` for each block.
    _BLOCK_SLICES = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        # Process-wide side effect: torch's cache root is pointed at the
        # current working directory.
        os.environ['TORCH_HOME'] = os.path.abspath(os.getcwd())

        # Build the network once and slice it, instead of constructing a
        # full VGG16 per block; the unused layers are dropped afterwards.
        features = torchvision.models.vgg16().features
        blocks = [
            features[start:stop].eval()
            for start, stop in self._BLOCK_SLICES
        ]
        # The feature extractor is fixed: none of its parameters are trained.
        for block in blocks:
            for param in block.parameters():
                param.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)

        # Per-channel statistics, shaped (1, 3, 1, 1) so they broadcast over
        # a 4-D batch. Registered as buffers so they follow ``.to(device)``.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )

    def _preprocess(self, img):
        """Rescale, normalise and resize one batch before feature extraction."""
        # Presumably maps inputs from [-1, 1] to [0, 1] -- TODO confirm the
        # value range the callers actually provide.
        img = img * 0.5 + 0.5
        img = (img - self.mean) / self.std
        return F.interpolate(
            img, mode='bilinear', size=(224, 224), align_corners=False
        )

    def forward(self, x, y):
        """Return the summed per-block L1 distance between the features.

        Args:
            x: First image batch. Assumed to be a 4-D tensor with three
                channels (it is combined with (1, 3, 1, 1) statistics and
                bilinearly resized) -- TODO confirm against callers.
            y: Second image batch, same layout as ``x``.

        Returns:
            The sum over all blocks of ``F.l1_loss`` between the two
            activations (a plain ``0.0`` if ``self.blocks`` is empty).
        """
        x = self._preprocess(x)
        y = self._preprocess(y)

        perceptual_loss = 0.0
        for block in self.blocks:
            # Each block consumes the previous block's output.
            x = block(x)
            y = block(y)
            perceptual_loss += F.l1_loss(x, y)
        return perceptual_loss
|
|