GANime / ganime /metrics /image.py
Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
No virus
2.43 kB
import numpy as np
import tensorflow as tf
from scipy import linalg
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tqdm.auto import tqdm
inceptionv3 = InceptionV3(include_top=False, weights="imagenet", pooling="avg")
def resize_images(images, new_shape):
images = tf.image.resize(images, new_shape)
return images
def calculate_fid(real_embeddings, generated_embeddings):
# calculate mean and covariance statistics
mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(
generated_embeddings, rowvar=False
)
# calculate sum squared difference between means
ssdiff = np.sum((mu1 - mu2) ** 2.0)
# calculate sqrt of product between cov
covmean = linalg.sqrtm(sigma1.dot(sigma2))
# check and correct imaginary numbers from sqrt
if np.iscomplexobj(covmean):
covmean = covmean.real
# calculate score
fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
return fid
def calculate_images_metrics(dataset, model, total_length):
fake_embeddings = []
real_embeddings = []
psnrs = []
ssims = []
for sample in tqdm(dataset, total=total_length):
generated = model(sample[0], training=False)[0]
generated, real = generated, sample[0]
real_resized = resize_images(real, (299, 299))
generated_resized = resize_images(generated, (299, 299))
real_activations = inceptionv3(real_resized, training=False)
generated_activations = inceptionv3(generated_resized, training=False)
fake_embeddings.append(generated_activations)
real_embeddings.append(real_activations)
fake_scaled = tf.cast(((generated * 0.5) + 1) * 255, tf.uint8)
real_scaled = tf.cast(((real * 0.5) + 1) * 255, tf.uint8)
psnrs.append(tf.image.psnr(fake_scaled, real_scaled, 255).numpy())
ssims.append(tf.image.ssim(fake_scaled, real_scaled, 255).numpy())
fid = calculate_fid(
tf.concat(fake_embeddings, axis=0).numpy(),
tf.concat(real_embeddings, axis=0).numpy(),
)
# kid = calculate_kid(
# tf.concat(fake_embeddings, axis=0).numpy(),
# tf.concat(real_embeddings, axis=0).numpy(),
# )
psnr = np.array(psnrs).mean()
ssim = np.array(ssims).mean()
return {"fid": fid, "ssim": ssim, "psnr": psnr}