# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
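"""Learned perceptual image patch similarity (LPIPS) metric.

Build the network with get_lpips_model() and call it on two image batches with
values in [0, 1] to obtain a per-image perceptual distance.
"""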
from collections import namedtuple

import torch
from torch import nn, distributed as dist
import torchvision.models as tv
from torch.distributed import barrier

from imaginaire.utils.distributed import is_local_master


def get_lpips_model():
    """Construct the LPIPS network; in distributed training, only the local master downloads the pretrained weights."""
    if dist.is_initialized() and not is_local_master():
        # Non-master processes wait here until the local master has downloaded the weights into the shared cache.
        barrier()

    model = LPIPSNet().cuda()

    if dist.is_initialized() and is_local_master():
        # The local master releases the waiting processes now that the weights are cached.
        barrier()
    return model


# Learned perceptual network, modified from https://github.com/richzhang/PerceptualSimilarity

def normalize_tensor(in_feat, eps=1e-5):
    # L2-normalize along the channel dimension; eps (applied both inside the
    # square root and in the division) guards against division by zero.
    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True) + eps)
    return in_feat / (norm_factor + eps)


class NetLinLayer(nn.Module):
    """A 1x1 channel-wise scaling layer; LPNet overwrites its weights with the learned LPIPS weights."""

    def __init__(self, dim):
        super(NetLinLayer, self).__init__()
        self.weight = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, inp):
        out = self.weight * inp
        return out


class ScalingLayer(nn.Module):
    # Normalizes [-1, 1] inputs with the channel statistics the LPIPS VGG16 backbone expects.
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class LPIPSNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = LPNet()

    @torch.no_grad()
    def forward(self, fake_images, fake_images_another, align_corners=True):
        # Note: align_corners is unused and kept only for interface compatibility.
        features, shape = self._forward_single(fake_images)
        features_another, _ = self._forward_single(fake_images_another)
        result = 0
        for i, g_feat in enumerate(features):
            # Squared distance between the flattened, weighted features, averaged over
            # spatial locations (shape[i] is the feature map width, assuming square maps).
            cur_diff = torch.sum((g_feat - features_another[i]) ** 2, dim=1) / (shape[i] ** 2)
            result += cur_diff
        return result

    def _forward_single(self, images):
        return self.model(torch.clamp(images, 0, 1))


class LPNet(nn.Module):
    def __init__(self):
        super(LPNet, self).__init__()

        self.scaling_layer = ScalingLayer()
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.L = 5
        dims = [64, 128, 256, 512, 512]
        self.lins = nn.ModuleList([NetLinLayer(dims[i]) for i in range(self.L)])

        weights = torch.hub.load_state_dict_from_url(
            'https://github.com/niopeng/CAM-Net/raw/main/code/models/weights/v0.1/vgg.pth'
        )
        for i in range(self.L):
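            # The released LPIPS weights multiply the squared feature difference; taking
            # the square root lets LPIPSNet scale features before differencing, since
            # (sqrt(w) * f0 - sqrt(w) * f1) ** 2 == w * (f0 - f1) ** 2.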
            self.lins[i].weight.data = torch.sqrt(weights["lin%d.model.1.weight" % i])

    def forward(self, in0, avg=False):
        in0 = 2 * in0 - 1  # map [0, 1] images to [-1, 1]
        in0_input = self.scaling_layer(in0)
        outs0 = self.net.forward(in0_input)
        feats0 = {}
        shapes = []
        res = []

        # Unit-normalize each feature map along the channel dimension.
        for kk in range(self.L):
            feats0[kk] = normalize_tensor(outs0[kk])

        if avg:
            res = [self.lins[kk](feats0[kk]).mean([2, 3], keepdim=False) for kk in range(self.L)]
        else:
            for kk in range(self.L):
                cur_res = self.lins[kk](feats0[kk])
                shapes.append(cur_res.shape[-1])  # spatial width, used by the caller for averaging
                res.append(cur_res.reshape(cur_res.shape[0], -1))

        return res, shapes


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
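        # The slices end at the VGG16 activations relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3.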
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h = self.slice1(x)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out
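

if __name__ == "__main__":
    # Minimal usage sketch: compute the LPIPS distance between two random image
    # batches. Assumes a CUDA device is available, since get_lpips_model() moves
    # the network to the GPU, and that the pretrained VGG16/LPIPS weights can be
    # downloaded (or are already cached).
    lpips = get_lpips_model()
    images_a = torch.rand(2, 3, 256, 256, device='cuda')  # images with values in [0, 1]
    images_b = torch.rand(2, 3, 256, 256, device='cuda')
    distances = lpips(images_a, images_b)  # one perceptual distance per image pair, shape (2,)
    print(distances)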