haakohu's picture
fix
44539fc
raw
history blame contribute delete
No virus
3.03 kB
import torch
import tops
import sys
from contextlib import redirect_stdout
from torch_fidelity.sample_similarity_lpips import NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average
class SampleSimilarityLPIPS(torch.nn.Module):
SUPPORTED_DTYPES = {
'uint8': torch.uint8,
'float32': torch.float32,
}
def __init__(self):
super().__init__()
self.chns = [64, 128, 256, 512, 512]
self.L = len(self.chns)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
with redirect_stdout(sys.stderr):
fp = tops.download_file(URL_VGG16_LPIPS)
state_dict = torch.load(fp, map_location="cpu")
self.load_state_dict(state_dict)
self.net = VGG16features()
self.eval()
for param in self.parameters():
param.requires_grad = False
mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2
inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255)
self.register_buffer("mean", mean_rescaled)
self.register_buffer("std", inv_std_rescaled)
def normalize(self, x):
# torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
x = (x.float() - self.mean) * self.std
return x
@staticmethod
def resize(x, size):
if x.shape[-1] > size and x.shape[-2] > size:
x = torch.nn.functional.interpolate(x, (size, size), mode='area')
else:
x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False)
return x
def lpips_from_feats(self, feats0, feats1):
diffs = {}
for kk in range(self.L):
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
val = sum(res)
return val
def get_feats(self, x):
assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW'
if x.shape[-2] < 16 or x.shape[-1] < 16: # Resize images < 16x16
f = 2
size = tuple([int(f*_) for _ in x.shape[-2:]])
x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False)
in0_input = self.normalize(x)
outs0 = self.net.forward(in0_input)
feats = {}
for kk in range(self.L):
feats[kk] = normalize_tensor(outs0[kk])
return feats
def forward(self, in0, in1):
feats0 = self.get_feats(in0)
feats1 = self.get_feats(in1)
return self.lpips_from_feats(feats0, feats1), feats0, feats1