File size: 2,294 Bytes
9eae6e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
import torch.nn.functional as F

from .vgg import VGG19, VGG16

class Perceptual16Loss(nn.Module):
    """Perceptual (content) and style losses computed on VGG16 features.

    Both losses compare the first three feature maps returned by ``self.vgg``
    using an L1 criterion.

    Args:
        weights: Per-layer multipliers for the content loss. Only the first
            three entries are read by ``calculate_pl``. Defaults to all ones.
    """

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        super().__init__()
        self.vgg = VGG16()
        self.criterion = torch.nn.L1Loss()
        # Tuple default avoids a shared mutable default; store a list copy so
        # later mutation of the caller's sequence does not leak into the loss.
        self.weights = list(weights)

    def calculate_pl(self, x, y):
        """Return the weighted L1 content loss between VGG16 features of x and y.

        Assumes ``self.vgg`` returns an indexable sequence of at least three
        feature tensors -- TODO confirm against ``VGG16.forward``.

        Returns:
            Scalar tensor on ``x.device``.
        """
        feat_output = self.vgg(x)
        feat_gt = self.vgg(y)

        content_loss = 0.0
        for i in range(3):
            content_loss += self.weights[i] * self.criterion(feat_output[i], feat_gt[i])
        return content_loss.to(device=x.device)

    def compute_gram(self, x):
        """Return the Gram matrix of a (b, c, h, w) tensor, normalised by c*h*w.

        Returns:
            Tensor of shape (b, c, c).
        """
        b, c, h, w = x.size()
        # reshape (not view) so non-contiguous feature maps are accepted too.
        f = x.reshape(b, c, h * w)
        f_T = f.transpose(1, 2)
        return f.bmm(f_T) / (h * w * c)

    def calc_style(self, x, y):
        """Return the L1 style loss between Gram matrices of VGG16 features.

        Compares the Gram matrices of the first three feature maps of x and y.
        """
        # Features come from self.vgg, the only extractor this class defines.
        feat_output = self.vgg(x)
        feat_gt = self.vgg(y)

        style_loss = 0.0
        for i in range(3):
            style_loss += self.criterion(
                self.compute_gram(feat_output[i]), self.compute_gram(feat_gt[i]))
        return style_loss

class Perceptual19Loss(nn.Module):
    """Perceptual (content) and style losses computed on VGG19 features.

    ``self.vgg`` is expected to return a mapping keyed by layer names such as
    ``'relu1_1'``. This is inferred from the indexing below -- confirm against
    ``VGG19.forward``.

    Args:
        weights: Per-layer multipliers for the five ``relu{k}_1`` content
            terms, in the order of ``_CONTENT_LAYERS``. Defaults to all ones.
    """

    # Layers compared by the content loss, in the order matched by ``weights``.
    _CONTENT_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
    # Layers whose Gram matrices are compared by the style loss.
    _STYLE_LAYERS = ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2')

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        super().__init__()
        self.vgg = VGG19()
        self.criterion = torch.nn.L1Loss()
        # Tuple default avoids a shared mutable default; keep a list copy.
        self.weights = list(weights)

    def calculate_pl(self, x, y):
        """Return the weighted L1 content loss over the ``relu{k}_1`` layers.

        Returns:
            Scalar tensor on ``x.device``.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        # Index weights explicitly (rather than zip) so too-short weights still
        # raise instead of silently dropping layers.
        for i, layer in enumerate(self._CONTENT_LAYERS):
            content_loss += self.weights[i] * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss.to(device=x.device)

    def compute_gram(self, x):
        """Return the Gram matrix of a (b, c, h, w) tensor, normalised by c*h*w.

        Returns:
            Tensor of shape (b, c, c).
        """
        b, c, h, w = x.size()
        # reshape (not view) so non-contiguous feature maps are accepted too.
        f = x.reshape(b, c, h * w)
        f_T = f.transpose(1, 2)
        return f.bmm(f_T) / (h * w * c)

    def calc_style(self, x, y):
        """Return the unweighted L1 style loss over Gram matrices of ``_STYLE_LAYERS``."""
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        style_loss = 0.0
        for layer in self._STYLE_LAYERS:
            style_loss += self.criterion(
                self.compute_gram(x_vgg[layer]), self.compute_gram(y_vgg[layer]))
        return style_loss