# AutoLink / utils_/loss.py
import os
import torch
import torch.nn.functional as F
import torchvision


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Cache downloaded weights under the current working directory.
        os.environ['TORCH_HOME'] = os.path.abspath(os.getcwd())
        # Pretrained ImageNet weights are required for a meaningful perceptual
        # loss; instantiate the backbone once and slice it into feature stages.
        vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
        blocks = [vgg.features[:4].eval(),     # relu1_2
                  vgg.features[4:9].eval(),    # relu2_2
                  vgg.features[9:16].eval(),   # relu3_3
                  vgg.features[16:23].eval()]  # relu4_3
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        # ImageNet channel statistics, registered as buffers so they move with .to(device).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x, y):
        # Map inputs from [-1, 1] to [0, 1], then apply ImageNet normalization.
        x = x * 0.5 + 0.5
        y = y * 0.5 + 0.5
        x = (x - self.mean) / self.std
        y = (y - self.mean) / self.std
        # Resize to the 224x224 resolution VGG16 was trained on.
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        y = F.interpolate(y, size=(224, 224), mode='bilinear', align_corners=False)
        perceptual_loss = 0.0
        # style_loss = 0.0  # used only by the disabled style term below
        for block in self.blocks:
            x = block(x)
            y = block(y)
            # L1 distance between feature activations at each VGG stage.
            perceptual_loss += F.l1_loss(x, y)
            # Optional Gram-matrix style loss, kept disabled as in the original:
            # b, ch, h, w = x.shape
            # act_x = x.reshape(b, ch, -1)
            # act_y = y.reshape(b, ch, -1)
            # gram_x = act_x @ act_x.permute(0, 2, 1) / (ch * h * w)
            # gram_y = act_y @ act_y.permute(0, 2, 1) / (ch * h * w)
            # style_loss += F.l1_loss(gram_x, gram_y)
        return perceptual_loss
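

# Hedged usage sketch (not part of the original file): the input shapes and
# the [-1, 1] value range are assumptions inferred from the rescaling done in
# forward(); the first run downloads the VGG16 weights into TORCH_HOME.
if __name__ == '__main__':
    loss_fn = VGGPerceptualLoss()
    # Two batches of 3-channel images in [-1, 1], e.g. reconstruction vs. target.
    pred = torch.rand(2, 3, 128, 128) * 2 - 1
    target = torch.rand(2, 3, 128, 128) * 2 - 1
    loss = loss_fn(pred, target)
    print(loss.item())  # scalar: sum of per-stage L1 feature distances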