File size: 3,173 Bytes
6e70c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import torch.nn.functional as F
import torch
from utils.filters_tensor import GaussianSmoothing, bgr2gray
from utils import pytorch_ssim
from torch import nn
from .hourglass import HourGlass
from torchvision.models.vgg import vgg19


def l2_loss(y_input, y_target):
    """Return the mean squared error between ``y_input`` and ``y_target``."""
    # "mean" is the default reduction of F.mse_loss; spelled out for clarity.
    return F.mse_loss(input=y_input, target=y_target, reduction="mean")


def l1_loss(y_input, y_target):
    """Return the mean absolute error between ``y_input`` and ``y_target``."""
    # "mean" is the default reduction of F.l1_loss; spelled out for clarity.
    return F.l1_loss(input=y_input, target=y_target, reduction="mean")


def gaussianL2(yInput, yTarget):
    """Return the MSE between Gaussian-smoothed versions of the two images.

    ``yInput`` is smoothed as-is; ``yTarget`` is first passed through
    ``bgr2gray`` and then smoothed with the same filter.
    """
    # data range [-1,1]
    # A fresh single-channel 11x11, sigma=2.0 filter is built on every call.
    blur = GaussianSmoothing(channels=1, kernel_size=11, sigma=2.0)
    blurredInput = blur(yInput)
    targetGray = bgr2gray(yTarget)
    blurredTarget = blur(targetGray)
    return F.mse_loss(blurredInput, blurredTarget)


def binL1(yInput):
    """Return the mean distance of ``|yInput|`` from 1.

    The value is zero exactly when every element is -1 or +1, so it
    penalises outputs that are not binary.
    """
    # data range is [-1,1]
    magnitude = yInput.abs()
    distanceToOne = (magnitude - 1.0).abs()
    return distanceToOne.mean()


def ssimLoss(yInput, yTarget):
    """Return ``1 - SSIM`` between ``yInput`` and the gray version of ``yTarget``.

    Both tensors are rescaled with ``x / 2 + 0.5`` (mapping [-1,1] onto
    [0,1]) before the similarity is computed with an 11-pixel window.
    """
    # data range is [-1,1]
    inputUnit = yInput / 2. + 0.5
    targetUnit = yTarget / 2. + 0.5
    similarity = pytorch_ssim.ssim(inputUnit, bgr2gray(targetUnit), window_size=11)
    return 1. - similarity


class InverseHalf(nn.Module):
    """Thin module wrapping a one-channel-in, one-channel-out ``HourGlass``."""

    def __init__(self):
        super().__init__()
        self.net = HourGlass(inChannel=1, outChannel=1)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.net(x)


class FeatureLoss:
    """L2 loss between a pretrained ``InverseHalf`` output and a target.

    Args:
        pretrainedPath: path of a checkpoint whose ``'state_dict'`` entry
            holds the ``InverseHalf`` weights.
        requireGrad: when false (default) every extractor parameter is
            frozen (``requires_grad = False``).
        multiGpu: when true (default) the extractor is wrapped in
            ``torch.nn.DataParallel`` and moved to CUDA.
    """

    def __init__(self, pretrainedPath, requireGrad=False, multiGpu=True):
        self.featureExactor = InverseHalf()
        if multiGpu:
            self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
        print("-loading feature extractor: {} ...".format(pretrainedPath))
        # Deserialize onto the CPU so a checkpoint saved from a GPU can be read
        # on any host; load_state_dict copies the values into the parameters
        # wherever they live.
        checkpoint = torch.load(pretrainedPath, map_location="cpu")
        stateDict = self._alignStateDictKeys(checkpoint['state_dict'], bool(multiGpu))
        self.featureExactor.load_state_dict(stateDict)
        print("-feature network loaded")
        if not requireGrad:
            for param in self.featureExactor.parameters():
                param.requires_grad = False

    @staticmethod
    def _alignStateDictKeys(stateDict, wrapped):
        """Make the ``module.`` key prefix agree with the model wrapping.

        ``DataParallel`` stores the wrapped network under ``module``, so its
        state-dict keys carry a ``module.`` prefix. A checkpoint saved with
        the wrapper cannot be loaded into a bare model and vice versa unless
        the prefix is added or removed. The input mapping is returned
        untouched when its keys already fit; a renamed copy is a plain dict.
        """
        prefix = "module."
        allPrefixed = all(key.startswith(prefix) for key in stateDict)
        nonePrefixed = not any(key.startswith(prefix) for key in stateDict)
        if wrapped and nonePrefixed:
            return {prefix + key: value for key, value in stateDict.items()}
        if not wrapped and allPrefixed:
            return {key[len(prefix):]: value for key, value in stateDict.items()}
        return stateDict

    def __call__(self, yInput, yTarget):
        """Return the MSE between the extractor's output for ``yInput`` and ``yTarget``."""
        inFeature = self.featureExactor(yInput)
        return l2_loss(inFeature, yTarget)


class Vgg19Loss:
    """L2 loss between truncated-VGG19 feature maps of two image batches.

    Inputs are expected in BGR channel order with values in [0,1]; they are
    normalised per channel and flipped along dim 1 before entering the
    frozen VGG19 feature stack.
    """

    def __init__(self, multiGpu=True):
        os.environ['TORCH_HOME']='~/bigdata/0ProgramS/checkpoints'
        # data in BGR format, [0,1] range
        # The statistics are written in RGB order (they match the ImageNet
        # values documented by torchvision) and reversed to line up with BGR.
        self.mean = [0.485, 0.456, 0.406][::-1]
        self.std = [0.229, 0.224, 0.225][::-1]
        backbone = vgg19(pretrained=True)
        # keep everything up to and including the maxpool after conv4_4
        keptLayers = list(backbone.features)[:28]
        self.featureExactor = nn.Sequential(*keptLayers).eval()
        for weight in self.featureExactor.parameters():
            weight.requires_grad = False
        if multiGpu:
            self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
        print('[*] Vgg19Loss init!')

    def normalize(self, tensor):
        """Return a normalised copy of ``tensor``; the argument is not modified."""
        channelMean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
        channelStd = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
        # Work in place on a clone so the caller's tensor stays untouched.
        normalized = tensor.clone()
        normalized.sub_(channelMean[None, :, None, None])
        normalized.div_(channelStd[None, :, None, None])
        return normalized

    def __call__(self, yInput, yTarget):
        """Return the MSE between the feature maps of ``yInput`` and ``yTarget``."""
        # flip(1) reverses the channel axis after normalisation.
        preparedInput = self.normalize(yInput).flip(1)
        preparedTarget = self.normalize(yTarget).flip(1)
        inFeature = self.featureExactor(preparedInput)
        targetFeature = self.featureExactor(preparedTarget)
        return l2_loss(inFeature, targetFeature)