import torch
import tops
import sys
from contextlib import redirect_stdout
from torch_fidelity.sample_similarity_lpips import (
    NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average,
)


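class SampleSimilarityLPIPS(torch.nn.Module):
    """LPIPS perceptual distance with a VGG16 backbone, built from torch_fidelity components.

    Inputs are Bx3xHxW image tensors with values in [0, 255].
    """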
    SUPPORTED_DTYPES = {
        'uint8': torch.uint8,
        'float32': torch.float32,
    }

    def __init__(self):
        super().__init__()
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
        self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
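        # Route anything printed during the download to stderr so it does not pollute stdout.
        with redirect_stdout(sys.stderr):
            fp = tops.download_file(URL_VGG16_LPIPS)
            state_dict = torch.load(fp, map_location="cpu")
        # Strict loading runs before self.net is registered, so the checkpoint keys must
        # match only the lin0..lin4 heads defined above.
        self.load_state_dict(state_dict)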
        self.net = VGG16features()
        self.eval()
        for param in self.parameters():
            param.requires_grad = False
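        # LPIPS ScalingLayer constants (defined for inputs in [-1, 1]) rescaled for inputs in [0, 255].
        mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2
        inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255)
        self.register_buffer("mean", mean_rescaled)
        # Note: the "std" buffer holds the *inverse* std, so normalize() multiplies by it.
        self.register_buffer("std", inv_std_rescaled)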

    def normalize(self, x):
        # Expects x in [0, 255]. Equivalent to torchvision-style ImageNet normalization of x / 255
        # with mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
        x = (x.float() - self.mean) * self.std
        return x
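
    # Derivation of the normalization constants (sketch; the ScalingLayer form follows the
    # LPIPS reference implementation, and the symbols m and s are ours, not from this file):
    # LPIPS scales an input y in [-1, 1] as (y - shift) / scale, with
    #   shift = [-.030, -.088, -.188],  scale = [.458, .448, .450].
    # Substituting y = 2 * x / 255 - 1 for x in [0, 255]:
    #   (y - shift) / scale = (x - (1 + shift) * 255 / 2) * 2 / (scale * 255)
    #                       = (x - self.mean) * self.std
    # With shift = 2 * m - 1 and scale = 2 * s, where m = [0.485, 0.456, 0.406] and
    # s = [0.229, 0.224, 0.225] are the torchvision ImageNet mean/std, this reduces to
    #   (x / 255 - m) / s
    # e.g. for the red channel: 2 * 0.485 - 1 = -0.030 and 2 * 0.229 = 0.458.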

    @staticmethod
    def resize(x, size):
        """Resize x to (size, size): 'area' when both spatial dims exceed size, else 'bilinear'."""
        if x.shape[-1] > size and x.shape[-2] > size:
            x = torch.nn.functional.interpolate(x, (size, size), mode='area')
        else:
            x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False)
        return x
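
    # Usage sketch for resize (illustrative shapes, not from the original source):
    #   SampleSimilarityLPIPS.resize(torch.rand(1, 3, 512, 512), 256)
    #       -> both sides > 256, so 'area' downsampling to (1, 3, 256, 256)
    #   SampleSimilarityLPIPS.resize(torch.rand(1, 3, 128, 512), 256)
    #       -> only one side > 256, so 'bilinear' resampling to (1, 3, 256, 256)
    # 'area' averages over source pixels, which reduces the aliasing that bilinear
    # sampling can introduce when shrinking an image.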

    def lpips_from_feats(self, feats0, feats1):
        diffs = {}
        for kk in range(self.L):
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
        val = sum(res)
        return val
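
    # The loop above computes (sketch; notation is ours):
    #   d(x0, x1) = sum_{l=0}^{L-1} mean_{h,w} lin_l((f0_l - f1_l) ** 2)
    # f0_l, f1_l: channel-normalized VGG16 activations of layer l (from get_feats);
    # lin_l: the NetLinLayer head for layer l, which in the LPIPS reference code is a
    # 1x1 convolution collapsing the C_l channels into one weighted sum per pixel.
    # Because get_feats returns plain per-layer dicts, features can be cached and compared
    # later, e.g. d = model.lpips_from_feats(cached_feats0, model.get_feats(x1)).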

    def get_feats(self, x):
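        assert x.dim() == 4 and x.shape[1] == 3, 'Input is not Bx3xHxW'
        if x.shape[-2] < 16 or x.shape[-1] < 16:
            # Upsample small images (either side < 16 px) once by a factor of 2.
            # Cast to float first, since bilinear interpolation may not support uint8 input.
            scale_factor = 2
            size = tuple(int(scale_factor * s) for s in x.shape[-2:])
            x = torch.nn.functional.interpolate(x.float(), size=size, mode="bilinear", align_corners=False)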
        in0_input = self.normalize(x)
        # Call the module directly (not .forward) so any registered hooks run.
        outs0 = self.net(in0_input)

        feats = {}
        for kk in range(self.L):
            feats[kk] = normalize_tensor(outs0[kk])
        return feats
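
    # Worked example of the small-image branch in get_feats (illustrative sizes):
    #   12 x 20 input -> 24 x 40 (height < 16, so both sides are doubled once)
    #   4 x 4 input   -> 8 x 8   (a single doubling; the result can still be < 16)
    #   16 x 16 input -> unchanged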

    def forward(self, in0, in1):
        """Return (LPIPS distance, features of in0, features of in1) for two Bx3xHxW batches."""
        feats0 = self.get_feats(in0)
        feats1 = self.get_feats(in1)
        return self.lpips_from_feats(feats0, feats1), feats0, feats1
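

# Usage sketch (illustrative; assumes the LPIPS weights can be downloaded and that inputs
# are Bx3xHxW images with values in [0, 255], as normalize() expects):
#   model = SampleSimilarityLPIPS()
#   img0 = torch.randint(0, 256, (4, 3, 64, 64), dtype=torch.uint8)
#   img1 = torch.randint(0, 256, (4, 3, 64, 64), dtype=torch.uint8)
#   with torch.no_grad():
#       dist, feats0, feats1 = model(img0, img1)  # one distance per sample in the batch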