import numpy as np
import torch
import tops
from dp2 import utils
from torch_fidelity.helpers import get_kwarg, vassert
from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS
from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity
from torchvision.transforms.functional import resize


def slerp(a, b, t):
    """Spherical linear interpolation between (batches of) vectors ``a`` and ``b``.

    Both endpoints are normalized to the unit sphere first; the result is the
    unit vector at fraction ``t`` of the great-circle arc from ``a`` to ``b``.
    """
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    d = (a * b).sum(dim=-1, keepdim=True)
    p = t * torch.acos(d)
    c = b - d * a
    c = c / c.norm(dim=-1, keepdim=True)
    d = a * torch.cos(p) + c * torch.sin(p)
    d = d / d.norm(dim=-1, keepdim=True)
    return d


@torch.no_grad()
def calculate_ppl(
        dataloader,
        generator,
        latent_space=None,
        data_len=None,
        upsample_size=None,
        **kwargs) -> dict:
    """Compute Perceptual Path Length (PPL) for ``generator``.

    Inspired by
    https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py

    Args:
        dataloader: iterable of batch dicts; each batch must contain an "img"
            tensor (used only for its batch size and as conditioning kwargs).
        generator: model exposing ``latent_space``, ``z_channels``, ``get_w``
            and a call returning a dict with an "img" tensor.
        latent_space: "Z" or "W"; defaults to ``generator.latent_space``.
        data_len: total number of latent pairs; defaults to
            ``len(dataloader) * dataloader.batch_size``.
        upsample_size: (H, W) minimum resolution fed to the similarity
            network. Required — there is no usable default.

    Returns:
        dict with "ppl/mean" and "ppl/std" as floats on rank 0; all other
        ranks return ``-1`` placeholders.
    """
    if latent_space is None:
        latent_space = generator.latent_space
    assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}"
    # Explicit None check: the original `len(upsample_size)` raised an opaque
    # TypeError when the default None was used.
    assert upsample_size is not None and len(upsample_size) == 2

    epsilon = PPL_DEFAULTS["ppl_epsilon"]
    interp = PPL_DEFAULTS['ppl_z_interp_mode']
    similarity_name = PPL_DEFAULTS['ppl_sample_similarity']
    sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize']
    sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype']
    discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower']
    discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher']
    vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number')
    vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile')
    vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile')
    if discard_percentile_lower is not None and discard_percentile_higher is not None:
        vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles')

    sample_similarity = create_sample_similarity(
        similarity_name,
        sample_similarity_resize=sample_similarity_resize,
        sample_similarity_dtype=sample_similarity_dtype,
        cuda=False,
        **kwargs)
    sample_similarity = tops.to_cuda(sample_similarity)
    rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
    if data_len is None:
        data_len = len(dataloader) * dataloader.batch_size

    # Endpoints of each latent path. For Z-space, take the epsilon-sized
    # interpolation step directly in Z; for W-space, the step is taken after
    # mapping to W inside the loop below.
    z0 = sample_random(rng, (data_len, generator.z_channels), "normal")
    z1 = sample_random(rng, (data_len, generator.z_channels), "normal")
    if latent_space == "Z":
        z1 = batch_interp(z0, z1, epsilon, interp)
    print("Computing PPL IN", latent_space)

    distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device())
    end = 0
    n_samples = 0
    for batch in utils.tqdm_(dataloader, desc="Perceptual Path Length"):
        start = end
        end = start + batch["img"].shape[0]
        n_samples += batch["img"].shape[0]
        batch_lat_e0 = tops.to_cuda(z0[start:end])
        batch_lat_e1 = tops.to_cuda(z1[start:end])
        if latent_space == "W":
            w0 = generator.get_w(batch_lat_e0, update_emas=False)
            w1 = generator.get_w(batch_lat_e1, update_emas=False)
            w1 = w0.lerp(w1, epsilon)  # second PPL endpoint: epsilon step in W
            rgb1 = generator(**batch, w=w0)["img"]
            rgb2 = generator(**batch, w=w1)["img"]
        else:
            rgb1 = generator(**batch, z=batch_lat_e0)["img"]
            rgb2 = generator(**batch, z=batch_lat_e1)["img"]
        # The similarity backbone expects a minimum input resolution.
        if rgb1.shape[-2] < upsample_size[0] or rgb1.shape[-1] < upsample_size[1]:
            rgb1 = resize(rgb1, upsample_size, antialias=True)
            rgb2 = resize(rgb2, upsample_size, antialias=True)
        rgb1 = utils.denormalize_img(rgb1).mul(255).byte()
        rgb2 = utils.denormalize_img(rgb2).mul(255).byte()
        sim = sample_similarity(rgb1, rgb2)
        dist_lat_e01 = sim / (epsilon ** 2)
        distances[start:end] = dist_lat_e01.view(-1)

    distances = distances[:n_samples]
    distances = tops.all_gather_uneven(distances).cpu().numpy()
    if tops.rank() != 0:
        # Only rank 0 computes real statistics; other ranks return sentinels.
        return {"ppl/mean": -1, "ppl/std": -1}

    # Optionally discard outliers outside the configured percentile band.
    # Using an all-True mask fixes the original bug where only-higher filtering
    # called np.logical_and(None, ...).
    # NOTE(review): `interpolation=` is deprecated in NumPy >= 1.22 in favor of
    # `method=`; kept for compatibility with the pinned NumPy — confirm pin.
    keep = np.ones(distances.shape, dtype=bool)
    if discard_percentile_lower is not None:
        lo = np.percentile(distances, discard_percentile_lower, interpolation='lower')
        keep &= lo <= distances
    if discard_percentile_higher is not None:
        hi = np.percentile(distances, discard_percentile_higher, interpolation='higher')
        keep &= distances <= hi
    distances = distances[keep]
    return {
        "ppl/mean": float(np.mean(distances)),
        "ppl/std": float(np.std(distances)),
    }