Spaces:
Runtime error
Runtime error
File size: 5,906 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import math
import os
from pathlib import Path
from cleanfid.inception_torchscript import InceptionV3W
import clip
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm.auto import trange
from . import utils
class InceptionV3FeatureExtractor(nn.Module):
    """Feature extractor backed by cleanfid's TorchScript ``InceptionV3W`` model.

    On construction the weights file ``inception-2015-12-05.pt`` is requested
    through ``utils.download_file`` into ``$XDG_CACHE_HOME/k-diffusion`` (or
    ``~/.cache/k-diffusion`` when the variable is unset) and the wrapped
    network is moved to ``device``.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        cache_root = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache'))
        weights_dir = cache_root / 'k-diffusion'
        weights_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
        weights_digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
        # The helper receives the destination file, the source URL and the
        # expected digest; InceptionV3W is then pointed at the directory.
        utils.download_file(weights_dir / 'inception-2015-12-05.pt', weights_url, weights_digest)
        # resize_inside=False: resizing is done in forward() instead.
        self.model = InceptionV3W(str(weights_dir), resize_inside=False).to(device)
        self.size = (299, 299)

    def forward(self, x):
        """Resize ``x`` to 299x299, rescale it and run it through the network.

        Args:
            x: image batch; presumably (N, C, H, W) with values in [-1, 1]
                -- TODO confirm against callers.

        Returns:
            Whatever the wrapped Inception model returns for the batch.
        """
        resized = F.interpolate(x, self.size, mode='bicubic', align_corners=False, antialias=True)
        # Single-channel input is replicated to three channels.
        if resized.shape[1] == 1:
            resized = torch.cat((resized, resized, resized), dim=1)
        # Affine map that sends [-1, 1] onto [0, 255]; out-of-range values are clipped.
        pixels = (resized * 127.5 + 127.5).clamp(0, 255)
        return self.model(pixels)
class CLIPFeatureExtractor(nn.Module):
    """Feature extractor that embeds images with a CLIP image encoder.

    The model is loaded through ``clip.load``, switched to eval mode and has
    gradients disabled. Input images are resized to the visual encoder's
    ``input_resolution`` before encoding.
    """

    def __init__(self, name='ViT-B/16', device='cpu'):
        super().__init__()
        loaded = clip.load(name, device=device)
        # Only the first element of what clip.load returns (the model) is kept.
        self.model = loaded[0].eval().requires_grad_(False)
        self.normalize = transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
        resolution = self.model.visual.input_resolution
        self.size = resolution, resolution

    @classmethod
    def available_models(cls):
        """Return the model names reported by ``clip.available_models()``."""
        return clip.available_models()

    def forward(self, x):
        """Encode an image batch into scaled, L2-normalised CLIP features.

        Args:
            x: image batch; presumably (N, C, H, W) with values in [-1, 1]
                -- TODO confirm against callers.

        Returns:
            float32 features whose rows are L2-normalised and then multiplied
            by the square root of the feature dimension.
        """
        images = (x + 1) / 2  # shift/scale so that [-1, 1] becomes [0, 1]
        images = F.interpolate(images, self.size, mode='bicubic', align_corners=False, antialias=True)
        # Single-channel input is replicated to three channels.
        if images.shape[1] == 1:
            images = torch.cat((images, images, images), dim=1)
        images = self.normalize(images)
        feats = self.model.encode_image(images).float()
        scale = feats.shape[-1] ** 0.5
        return F.normalize(feats) * scale
class DINOv2FeatureExtractor(nn.Module):
    """Feature extractor that embeds images with a DINOv2 backbone.

    The backbone ``'dinov2_' + name`` is fetched with ``torch.hub.load`` from
    the ``facebookresearch/dinov2`` repository, moved to ``device``, put in
    eval mode and has gradients disabled. Inputs are resized to 224x224.
    """

    def __init__(self, name='vitl14', device='cpu'):
        super().__init__()
        backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_' + name)
        self.model = backbone.to(device).eval().requires_grad_(False)
        self.normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        self.size = 224, 224

    @classmethod
    def available_models(cls):
        """Return the backbone name suffixes accepted by the constructor."""
        return ['vits14', 'vitb14', 'vitl14', 'vitg14']

    def forward(self, x):
        """Encode an image batch into scaled, L2-normalised DINOv2 features.

        Args:
            x: image batch; presumably (N, C, H, W) with values in [-1, 1]
                -- TODO confirm against callers.

        Returns:
            float32 features whose rows are L2-normalised and then multiplied
            by the square root of the feature dimension.
        """
        x = (x + 1) / 2  # shift/scale so that [-1, 1] becomes [0, 1]
        x = F.interpolate(x, self.size, mode='bicubic', align_corners=False, antialias=True)
        # Single-channel input is replicated to three channels.
        if x.shape[1] == 1:
            x = torch.cat([x] * 3, dim=1)
        x = self.normalize(x)
        # torch.autocast('cuda', ...) is the non-deprecated spelling of
        # torch.cuda.amp.autocast(...); the latter emits a deprecation warning
        # on every call in recent PyTorch releases. Behaviour is unchanged.
        with torch.autocast('cuda', dtype=torch.float16):
            x = self.model(x).float()
        x = F.normalize(x) * x.shape[-1] ** 0.5
        return x
def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
    """Collect features for ``n`` samples, splitting the work across processes.

    Each process draws batches from ``sample_fn``, runs them through
    ``extractor_fn`` and gathers the results from all processes with
    ``accelerator.gather``.

    Args:
        accelerator: object exposing ``num_processes``, ``is_main_process``
            and ``gather`` (presumably a Hugging Face ``Accelerator`` --
            confirm against callers).
        sample_fn: callable taking a batch size and returning a sliceable
            batch of samples. It may raise ``StopIteration`` to signal that
            no more samples are available.
        extractor_fn: callable mapping a batch of samples to a feature tensor.
        n: total number of feature rows wanted across all processes.
        batch_size: maximum number of samples per process per step.

    Returns:
        The gathered features concatenated along dim 0 and truncated to at
        most ``n`` rows.
    """
    n_per_proc = math.ceil(n / accelerator.num_processes)
    feats_all = []
    try:
        for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
            # Cap the batch by this process's remaining share, not by the
            # global count: capping by ``n - i`` made every process generate
            # up to ``num_processes`` times more samples than needed, only
            # for the surplus to be dropped by the final ``[:n]``.
            cur_batch_size = min(n_per_proc - i, batch_size)
            samples = sample_fn(cur_batch_size)[:cur_batch_size]
            feats_all.append(accelerator.gather(extractor_fn(samples)))
    except StopIteration:
        # Deliberate: an exhausted sample source ends collection early and
        # whatever was gathered so far is returned.
        pass
    return torch.cat(feats_all)[:n]
def polynomial_kernel(x, y):
    """Evaluate the cubic polynomial kernel ``(x . y / d + 1) ** 3``.

    Args:
        x: tensor of shape (..., m, d).
        y: tensor of shape (..., n, d).

    Returns:
        Tensor of shape (..., m, n) with the kernel value for every pair of
        rows, where ``d`` is the size of the last dimension of ``x``.
    """
    feature_dim = x.shape[-1]
    gram = x @ y.transpose(-1, -2)
    scaled = gram / feature_dim + 1
    return scaled ** 3
def squared_mmd(x, y, kernel=polynomial_kernel):
    """Estimate the squared maximum mean discrepancy between two sample sets.

    The within-set averages exclude the diagonal (i == j) kernel entries and
    are normalised by ``m * (m - 1)`` and ``n * (n - 1)``; the cross term is
    averaged over all ``m * n`` pairs.

    Args:
        x: tensor of shape (..., m, d).
        y: tensor of shape (..., n, d).
        kernel: callable returning the pairwise kernel matrix of its two
            arguments.

    Returns:
        Tensor with the estimate for each leading batch index.
    """
    def off_diagonal_total(gram):
        # Sum of every entry minus the trace, i.e. only the i != j pairs.
        return gram.sum([-1, -2]) - gram.diagonal(dim1=-1, dim2=-2).sum(-1)

    m, n = x.shape[-2], y.shape[-2]
    within_x = off_diagonal_total(kernel(x, x)) / m / (m - 1)
    within_y = off_diagonal_total(kernel(y, y)) / n / (n - 1)
    across = kernel(x, y).sum([-1, -2]) * 2 / m / n
    return within_x + within_y - across
@utils.tf32_mode(matmul=False)
def kid(x, y, max_size=5000):
    """Average ``squared_mmd`` over aligned partitions of two feature sets.

    Both inputs are cut along dim 0 into the same number of contiguous,
    near-equal slices, chosen so that each slice holds about ``max_size``
    rows at most. The per-slice squared MMD values are then averaged.

    NOTE(review): the decorator presumably turns TF32 matmul off for the
    duration of the call -- confirm in ``utils.tf32_mode``.

    Args:
        x: feature tensor whose first dimension indexes samples.
        y: feature tensor whose first dimension indexes samples.
        max_size: target upper bound on rows per partition.

    Returns:
        Scalar tensor with the mean of the per-partition estimates.
    """
    num_x, num_y = x.shape[0], y.shape[0]
    n_partitions = math.ceil(max(num_x / max_size, num_y / max_size))

    def edge(size, k):
        # Start index of the k-th of n_partitions slices over `size` rows.
        return round(k * size / n_partitions)

    total_mmd = x.new_zeros([])
    for part in range(n_partitions):
        x_part = x[edge(num_x, part):edge(num_x, part + 1)]
        y_part = y[edge(num_y, part):edge(num_y, part + 1)]
        total_mmd = total_mmd + squared_mmd(x_part, y_part)
    return total_mmd / n_partitions
class _MatrixSquareRootEig(torch.autograd.Function):
    """Matrix square root via ``torch.linalg.eigh`` with a hand-written backward.

    The absolute value of each eigenvalue is taken before the square root, so
    negative eigenvalues do not produce NaNs in the forward pass.
    """

    @staticmethod
    def forward(ctx, a):
        """Return ``V diag(sqrt(|w|)) V^T`` for the eigendecomposition of ``a``."""
        eigvals, eigvecs = torch.linalg.eigh(a)
        ctx.save_for_backward(eigvals, eigvecs)
        root_vals = eigvals.abs().sqrt()
        return eigvecs @ root_vals.diag_embed() @ eigvecs.transpose(-2, -1)

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to the input matrix.

        In the eigenbasis the incoming gradient is divided elementwise by
        ``s_i + s_j`` (``s`` being the square-rooted eigenvalue magnitudes)
        and then rotated back.
        """
        eigvals, eigvecs = ctx.saved_tensors
        root_vals = eigvals.abs().sqrt()
        # Broadcasting builds the matrix whose (i, j) entry is s_i + s_j.
        denom = root_vals.unsqueeze(-1) + root_vals.unsqueeze(-2)
        eigvecs_t = eigvecs.transpose(-2, -1)
        projected = eigvecs_t @ grad_output @ eigvecs
        return eigvecs @ (projected / denom) @ eigvecs_t
def sqrtm_eig(a):
    """Compute the matrix square root of ``a`` through ``_MatrixSquareRootEig``.

    Args:
        a: tensor whose last two dimensions form square matrices.

    Returns:
        The result of ``_MatrixSquareRootEig.apply(a)``.

    Raises:
        RuntimeError: if ``a`` has fewer than two dimensions or its last two
            dimensions differ in size.
    """
    if a.ndim < 2:
        raise RuntimeError('tensor of matrices must have at least 2 dimensions')
    rows, cols = a.shape[-2:]
    if rows != cols:
        raise RuntimeError('tensor must be batches of square matrices')
    return _MatrixSquareRootEig.apply(a)
@utils.tf32_mode(matmul=False)
def fid(x, y, eps=1e-8):
    """Compute the Frechet distance between Gaussians fitted to two feature sets.

    The value returned is
    ``|mu_x - mu_y|^2 + Tr(C_x + C_y - 2 * sqrt(sqrt(C_x) @ C_y @ sqrt(C_x)))``
    where the means and covariances are taken over dim 0 of each input and
    ``eps`` times the identity is added to both covariances beforehand.

    NOTE(review): the decorator presumably turns TF32 matmul off for the
    duration of the call -- confirm in ``utils.tf32_mode``.

    Args:
        x: 2-D feature tensor with one sample per row.
        y: 2-D feature tensor with one sample per row.
        eps: diagonal loading added to each covariance matrix.

    Returns:
        Scalar tensor with the distance.
    """
    mean_diff = x.mean(dim=0) - y.mean(dim=0)
    mean_term = mean_diff.pow(2).sum()

    # torch.cov expects variables in rows, hence the transposes.
    cov_x = torch.cov(x.T)
    cov_y = torch.cov(y.T)
    ridge = torch.eye(cov_x.shape[0], device=cov_x.device, dtype=cov_x.dtype) * eps
    cov_x = cov_x + ridge
    cov_y = cov_y + ridge

    root_x = sqrtm_eig(cov_x)
    cross_root = sqrtm_eig(root_x @ cov_y @ root_x)
    cov_term = torch.trace(cov_x + cov_y - 2 * cross_root)
    return mean_term + cov_term
|