Spaces:
Build error
Build error
import argparse
import functools
import logging
import os

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import scipy
import scipy.linalg
from tqdm import tqdm

from . import inception
from . import utils
logger = logging.getLogger(__name__)


class FID:
    """Computes the Frechet Inception Distance between a generator and a dataset.

    Statistics (mean and covariance) of InceptionV3 activations are computed
    for the real images and for generated images, then compared with the
    Frechet distance between the two Gaussians.
    """

    def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
        """
        Evaluates the FID score for a given generator and a given dataset.

        Implementation mostly taken from https://github.com/matthias-wright/jax-fid

        Reference: https://arxiv.org/abs/1706.08500

        Args:
            generator (nn.Module): Generator network.
            dataset (tf.data.Dataset): Dataset containing the real images.
            config (argparse.Namespace): Configuration. The attributes read here are
                                         ``num_fid_images``, ``batch_size``, ``c_dim``,
                                         ``z_dim`` and ``resolution``.
            use_cache (bool): If True, only compute the activation stats once for the
                              real images and store them.
            truncation_psi (float): Controls truncation (trading off variation for quality).
                                    If 1, truncation is disabled.
        """
        self.num_images = config.num_fid_images
        self.batch_size = config.batch_size
        self.c_dim = config.c_dim
        self.z_dim = config.z_dim
        self.dataset = dataset
        self.num_devices = jax.device_count()
        self.num_local_devices = jax.local_device_count()
        self.use_cache = use_cache
        # The cache dict always exists so that no method has to guard against a
        # missing attribute; it is only written to when use_cache is True.
        self.cache = {}

        rng = jax.random.PRNGKey(0)
        inception_net = inception.InceptionV3(pretrained=True)
        dummy_input = jnp.ones((1, config.resolution, config.resolution, 3))
        self.inception_params = inception_net.init(rng, dummy_input)
        # Replicate the parameters because the network is applied through pmap.
        self.inception_params = flax.jax_utils.replicate(self.inception_params)

        self.inception_apply = jax.pmap(
            functools.partial(inception_net.apply, train=False),
            axis_name='batch',
        )
        self.generator_apply = jax.pmap(
            functools.partial(
                generator.apply,
                truncation_psi=truncation_psi,
                train=False,
                noise_mode='const',
            ),
            axis_name='batch',
        )

    def compute_fid(self, generator_params, seed_offset=0):
        """Return the FID score of the generator with the given parameters.

        Args:
            generator_params: Generator parameters (un-replicated); they are
                              replicated here before being passed to the pmapped generator.
            seed_offset (int): Offset added to the per-batch PRNG seeds used for sampling.

        Returns:
            The Frechet distance between dataset and generator activation statistics.
        """
        generator_params = flax.jax_utils.replicate(generator_params)
        mu_real, sigma_real = self.compute_stats_for_dataset()
        mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
        return self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)

    def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
        """Return the Frechet distance between two Gaussians.

        Computes ``||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))``.

        Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py

        Args:
            mu1, mu2: Mean vectors (scalars are promoted to 1-D).
            sigma1, sigma2: Covariance matrices (scalars are promoted to 1x1 matrices).
            eps (float): Value added to the diagonal of both covariances if the
                         matrix square root of their product is not finite.

        Returns:
            The Frechet distance as a float scalar.

        Raises:
            AssertionError: If the means or the covariances have mismatching shapes.
            ValueError: If the matrix square root has a non-negligible imaginary diagonal.
        """
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        # Covariances must be at least 2-D so that dot/sqrtm/trace also work
        # for the single-feature (scalar covariance) case.
        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert mu1.shape == mu2.shape
        assert sigma1.shape == sigma2.shape

        diff = mu1 - mu2

        # The `disp` argument of sqrtm is deprecated in recent SciPy releases;
        # the plain call returns only the matrix on every version.
        covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
        if not np.isfinite(covmean).all():
            logger.info(
                'fid calculation produces singular product; '
                'adding %s to diagonal of cov estimates',
                eps,
            )
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

    def _activation_statistics(self, activations):
        """Return (mean, covariance) over at most ``self.num_images`` rows of the batched activations."""
        stacked = jnp.concatenate(activations, axis=0)[:self.num_images]
        mu = np.mean(stacked, axis=0)
        sigma = np.cov(stacked, rowvar=False)
        return mu, sigma

    def compute_stats_for_dataset(self):
        """Return the mean and covariance of the Inception activations of the dataset.

        When ``use_cache`` is True the result is stored and reused on later calls.

        Returns:
            Tuple ``(mu, sigma)``.

        Raises:
            ValueError: If the dataset yields no batches.
        """
        if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
            logger.info('Use cached statistics for dataset...')
            return self.cache['mu'], self.cache['sigma']

        print()
        logger.info('Compute statistics for dataset...')
        images_per_step = self.num_local_devices * self.batch_size
        image_count = 0
        activations = []
        for batch in utils.prefetch(self.dataset, n_prefetch=2):
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
            # Merge the device and per-device batch axes into one sample axis.
            act = jnp.reshape(act, (images_per_step, -1))
            activations.append(act)
            image_count += images_per_step
            if image_count >= self.num_images:
                break

        if not activations:
            raise ValueError('Dataset yielded no batches; cannot compute FID statistics.')
        if image_count < self.num_images:
            logger.warning(
                'Dataset provided only %d images, fewer than the requested %d.',
                image_count,
                self.num_images,
            )

        mu, sigma = self._activation_statistics(activations)
        # Only populate the cache when caching was requested.
        if self.use_cache:
            self.cache['mu'] = mu
            self.cache['sigma'] = sigma
        return mu, sigma

    def compute_stats_for_generator(self, generator_params, seed_offset):
        """Return the mean and covariance of the Inception activations of generated images.

        Args:
            generator_params: Replicated generator parameters for the pmapped generator.
            seed_offset (int): Batch ``i`` is sampled from ``PRNGKey(seed_offset + i)``.

        Returns:
            Tuple ``(mu, sigma)``.
        """
        print()
        logger.info('Compute statistics for generator...')
        images_per_step = self.batch_size * self.num_local_devices
        num_batches = int(np.ceil(self.num_images / images_per_step))
        activations = []

        for i in range(num_batches):
            rng = jax.random.PRNGKey(seed_offset + i)
            # Use independent keys for latents and labels; drawing both from the
            # same key would correlate them. NOTE: this changes the sampled
            # latents compared to drawing them directly from `rng`.
            rng_z, rng_label = jax.random.split(rng)
            z_latent = jax.random.normal(
                rng_z, shape=(self.num_local_devices, self.batch_size, self.z_dim)
            )

            labels = None
            if self.c_dim > 0:
                labels = jax.random.randint(
                    rng_label, shape=(images_per_step,), minval=0, maxval=self.c_dim
                )
                labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
                labels = jnp.reshape(labels, (self.num_local_devices, self.batch_size, self.c_dim))

            image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
            # Rescale the whole batch to [-1, 1] using its global min and max.
            image_min = jnp.min(image)
            image = (image - image_min) / (jnp.max(image) - image_min)
            image = 2 * image - 1

            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
            act = jnp.reshape(act, (images_per_step, -1))
            activations.append(act)

        return self._activation_statistics(activations)