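# LPIPS perceptual similarity, adapted from torch_fidelity's SampleSimilarityLPIPS:
# computes the LPIPS distance between two image batches from VGG16 features, loading
# the pretrained linear-layer weights through tops.download_file.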
import sys
from contextlib import redirect_stdout

import torch
import tops
from torch_fidelity.sample_similarity_lpips import (
    NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average
)

class SampleSimilarityLPIPS(torch.nn.Module):
    SUPPORTED_DTYPES = {
        'uint8': torch.uint8,
        'float32': torch.float32,
    }

    def __init__(self):
        super().__init__()
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
        self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        # Fetch the pretrained LPIPS linear-layer weights; any stdout output from the
        # download is redirected to stderr.
        with redirect_stdout(sys.stderr):
            fp = tops.download_file(URL_VGG16_LPIPS)
            state_dict = torch.load(fp, map_location="cpu")
        self.load_state_dict(state_dict)
        self.net = VGG16features()
        self.eval()
        for param in self.parameters():
            param.requires_grad = False
        # LPIPS shift/scale constants rescaled so that normalize() accepts inputs in [0, 255].
        mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2
        inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255)
        self.register_buffer("mean", mean_rescaled)
        self.register_buffer("std", inv_std_rescaled)

    def normalize(self, x):
        # torchvision values in range [0,1]: mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
        x = (x.float() - self.mean) * self.std
        return x

    @staticmethod
    def resize(x, size):
        if x.shape[-1] > size and x.shape[-2] > size:
            x = torch.nn.functional.interpolate(x, (size, size), mode='area')
        else:
            x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False)
        return x

    def lpips_from_feats(self, feats0, feats1):
        # Squared differences of unit-normalized features, weighted by the learned 1x1
        # convolutions, spatially averaged, and summed over the five VGG16 stages.
        diffs = {}
        for kk in range(self.L):
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
        res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
        val = sum(res)
        return val

    def get_feats(self, x):
        assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW'
        if x.shape[-2] < 16 or x.shape[-1] < 16:  # Upscale images smaller than 16x16 by a factor of 2
            f = 2
            size = tuple([int(f * _) for _ in x.shape[-2:]])
            x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False)
        in0_input = self.normalize(x)
        outs0 = self.net.forward(in0_input)
        feats = {}
        for kk in range(self.L):
            feats[kk] = normalize_tensor(outs0[kk])
        return feats

    def forward(self, in0, in1):
        feats0 = self.get_feats(in0)
        feats1 = self.get_feats(in1)
        return self.lpips_from_feats(feats0, feats1), feats0, feats1
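

if __name__ == "__main__":
    # Minimal usage sketch (assumption: inputs are Bx3xHxW batches with values in [0, 255],
    # which is the range the rescaled mean/std buffers registered in __init__ expect).
    lpips = SampleSimilarityLPIPS()
    img0 = torch.randint(0, 256, (4, 3, 128, 128), dtype=torch.uint8)
    img1 = torch.randint(0, 256, (4, 3, 128, 128), dtype=torch.uint8)
    with torch.no_grad():
        distance, feats0, feats1 = lpips(img0, img1)
    # distance holds one LPIPS value per image pair, shape (4, 1, 1, 1) after spatial averaging.
    print(distance.flatten())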