menghanxia's picture
created the space
6e70c4a
raw
history blame
3.17 kB
import os
import torch.nn.functional as F
import torch
from utils.filters_tensor import GaussianSmoothing, bgr2gray
from utils import pytorch_ssim
from torch import nn
from .hourglass import HourGlass
from torchvision.models.vgg import vgg19
def l2_loss(y_input, y_target):
return F.mse_loss(y_input, y_target)
def l1_loss(y_input, y_target):
return F.l1_loss(y_input, y_target)
def gaussianL2(yInput, yTarget):
# data range [-1,1]
smoother = GaussianSmoothing(channels=1, kernel_size=11, sigma=2.0)
gaussianInput = smoother(yInput)
gaussianTarget = smoother(bgr2gray(yTarget))
return F.mse_loss(gaussianInput, gaussianTarget)
def binL1(yInput):
# data range is [-1,1]
return (yInput.abs() - 1.0).abs().mean()
def ssimLoss(yInput, yTarget):
# data range is [-1,1]
ssim = pytorch_ssim.ssim(yInput / 2. + 0.5, bgr2gray(yTarget / 2. + 0.5), window_size=11)
return 1. - ssim
class InverseHalf(nn.Module):
def __init__(self):
super(InverseHalf, self).__init__()
self.net = HourGlass(inChannel=1, outChannel=1)
def forward(self, x):
grayscale = self.net(x)
return grayscale
class FeatureLoss:
def __init__(self, pretrainedPath, requireGrad=False, multiGpu=True):
self.featureExactor = InverseHalf()
if multiGpu:
self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
print("-loading feature extractor: {} ...".format(pretrainedPath))
checkpoint = torch.load(pretrainedPath)
self.featureExactor.load_state_dict(checkpoint['state_dict'])
print("-feature network loaded")
if not requireGrad:
for param in self.featureExactor.parameters():
param.requires_grad = False
def __call__(self, yInput, yTarget):
inFeature = self.featureExactor(yInput)
return l2_loss(inFeature, yTarget)
class Vgg19Loss:
def __init__(self, multiGpu=True):
os.environ['TORCH_HOME']='~/bigdata/0ProgramS/checkpoints'
# data in BGR format, [0,1] range
self.mean = [0.485, 0.456, 0.406]
self.mean.reverse()
self.std = [0.229, 0.224, 0.225]
self.std.reverse()
vgg = vgg19(pretrained=True)
# maxpoll after conv4_4
self.featureExactor = nn.Sequential(*list(vgg.features)[:28]).eval()
for param in self.featureExactor.parameters():
param.requires_grad = False
if multiGpu:
self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
print('[*] Vgg19Loss init!')
def normalize(self, tensor):
tensor = tensor.clone()
mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
return tensor
def __call__(self, yInput, yTarget):
inFeature = self.featureExactor(self.normalize(yInput).flip(1))
targetFeature = self.featureExactor(self.normalize(yTarget).flip(1))
return l2_loss(inFeature, targetFeature)