import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import numpy as np
import os
import functools
import argparse
import scipy.linalg
from tqdm import tqdm
import logging
from . import inception
from . import utils
logger = logging.getLogger(__name__)
class FID:
def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
"""
Evaluates the FID score for a given generator and a given dataset.
Implementation mostly taken from https://github.com/matthias-wright/jax-fid
Reference: https://arxiv.org/abs/1706.08500
Args:
generator (nn.Module): Generator network.
dataset (tf.data.Dataset): Dataset containing the real images.
config (argparse.Namespace): Configuration.
            use_cache (bool): If True, compute the activation statistics for the real images only once and cache them.
truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
"""
self.num_images = config.num_fid_images
self.batch_size = config.batch_size
self.c_dim = config.c_dim
self.z_dim = config.z_dim
self.dataset = dataset
self.num_devices = jax.device_count()
self.num_local_devices = jax.local_device_count()
self.use_cache = use_cache
if self.use_cache:
self.cache = {}
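        # Build the InceptionV3 feature extractor once and replicate its
        # parameters so the pmapped apply can run on all local devices.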
rng = jax.random.PRNGKey(0)
inception_net = inception.InceptionV3(pretrained=True)
self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3)))
self.inception_params = flax.jax_utils.replicate(self.inception_params)
self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch')
self.generator_apply = jax.pmap(functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'), axis_name='batch')
def compute_fid(self, generator_params, seed_offset=0):
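        # Replicate the generator parameters across local devices for the
        # pmapped forward passes.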
generator_params = flax.jax_utils.replicate(generator_params)
mu_real, sigma_real = self.compute_stats_for_dataset()
mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)
return fid_score
def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
# Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
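        # FID = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrtm(sigma1 @ sigma2))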
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_1d(sigma1)
sigma2 = np.atleast_1d(sigma2)
assert mu1.shape == mu2.shape
assert sigma1.shape == sigma2.shape
diff = mu1 - mu2
covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = ('fid calculation produces singular product; '
'adding %s to diagonal of cov estimates') % eps
logger.info(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
def compute_stats_for_dataset(self):
if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
logger.info('Use cached statistics for dataset...')
return self.cache['mu'], self.cache['sigma']
print()
logger.info('Compute statistics for dataset...')
image_count = 0
activations = []
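        # Stream real images through Inception until num_images activations are
        # collected; each pmapped call yields per-device features that are
        # flattened to (num_local_devices * batch_size, feature_dim).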
for batch in utils.prefetch(self.dataset, n_prefetch=2):
act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
act = jnp.reshape(act, (self.num_local_devices * self.batch_size, -1))
activations.append(act)
image_count += self.num_local_devices * self.batch_size
if image_count >= self.num_images:
break
activations = jnp.concatenate(activations, axis=0)
activations = activations[:self.num_images]
mu = np.mean(activations, axis=0)
sigma = np.cov(activations, rowvar=False)
        if self.use_cache:
            self.cache['mu'] = mu
            self.cache['sigma'] = sigma
return mu, sigma
def compute_stats_for_generator(self, generator_params, seed_offset):
print()
logger.info('Compute statistics for generator...')
num_batches = int(np.ceil(self.num_images / (self.batch_size * self.num_local_devices)))
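        # Enough batches to cover num_images across all local devices.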
activations = []
        for i in tqdm(range(num_batches)):
            # Derive independent keys so the latents and the class labels do not
            # reuse the same random stream.
            rng = jax.random.PRNGKey(seed_offset + i)
            z_rng, label_rng = jax.random.split(rng)
            z_latent = jax.random.normal(z_rng, shape=(self.num_local_devices, self.batch_size, self.z_dim))
            labels = None
            if self.c_dim > 0:
                labels = jax.random.randint(label_rng, shape=(self.num_local_devices * self.batch_size,), minval=0, maxval=self.c_dim)
                labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
                labels = jnp.reshape(labels, (self.num_local_devices, self.batch_size, self.c_dim))
image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
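            # Min-max normalize the generated batch to [0, 1], then rescale to
            # [-1, 1] before extracting Inception features.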
image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
image = 2 * image - 1
act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
act = jnp.reshape(act, (self.num_local_devices * self.batch_size, -1))
activations.append(act)
activations = jnp.concatenate(activations, axis=0)
activations = activations[:self.num_images]
mu = np.mean(activations, axis=0)
sigma = np.cov(activations, rowvar=False)
return mu, sigma
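

# Example usage (a minimal sketch; `generator`, `dataset`, and a `config` with the
# fields read above -- num_fid_images, batch_size, c_dim, z_dim, resolution -- are
# assumed to come from the surrounding training code):
#
#   fid = FID(generator, dataset, config, use_cache=True)
#   score = fid.compute_fid(generator_params, seed_offset=0)
#   logger.info(f'FID: {score:.2f}')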