|
import numpy as np |
|
import tensorflow as tf |
|
from scipy import linalg |
|
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input |
|
from tqdm.auto import tqdm |
|
|
|
# Shared feature extractor used for the FID embeddings: InceptionV3 with
# ImageNet weights, no classification head, and global average pooling applied
# to the last convolutional feature map. Built once at import time.
inceptionv3 = InceptionV3(
    weights="imagenet",
    include_top=False,
    pooling="avg",
)
|
|
|
|
|
def resize_images(images, new_shape):
    """Resize ``images`` to ``new_shape`` using ``tf.image.resize`` defaults.

    Args:
        images: Image tensor accepted by ``tf.image.resize``.
        new_shape: Target size forwarded unchanged to ``tf.image.resize``,
            e.g. ``(299, 299)``.

    Returns:
        The resized tensor produced by ``tf.image.resize``.
    """
    return tf.image.resize(images, new_shape)
|
|
|
|
|
def calculate_fid(real_embeddings, generated_embeddings, eps=1e-6):
    """Compute the Frechet distance between two sets of embeddings.

    Each set is summarised by its per-feature mean and its covariance matrix
    (rows are samples, columns are features). The returned value is

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
        real_embeddings: 2-D array-like of shape (n_real, n_features).
        generated_embeddings: 2-D array-like of shape
            (n_generated, n_features).
        eps: Value added to the diagonal of both covariance matrices, used
            only when the matrix square root of their product is not finite.

    Returns:
        The Frechet distance as a NumPy floating-point scalar.
    """
    # Work in float64 so means and covariances share the same precision.
    real = np.asarray(real_embeddings, dtype=np.float64)
    generated = np.asarray(generated_embeddings, dtype=np.float64)

    mu1 = real.mean(axis=0)
    mu2 = generated.mean(axis=0)

    # np.cov squeezes its result to a 0-d array when there is one feature;
    # sqrtm needs a proper (n_features, n_features) matrix.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(generated, rowvar=False))

    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    covmean = linalg.sqrtm(sigma1.dot(sigma2))

    # A singular product can yield a non-finite square root; retry with a
    # small diagonal offset so the result is usable.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can introduce an imaginary component; only the real
    # part is used.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid
|
|
|
|
|
def _images_to_uint8(images):
    """Convert float images to uint8 pixel values in [0, 255].

    NOTE(review): assumes ``images`` are scaled to [-1, 1] -- confirm against
    the data pipeline. Values outside that range are clipped so the uint8
    cast cannot overflow.
    """
    scaled = (images * 0.5 + 0.5) * 255
    return tf.cast(tf.clip_by_value(scaled, 0.0, 255.0), tf.uint8)


def calculate_images_metrics(dataset, model, total_length):
    """Compare a model's outputs with its inputs using FID, SSIM and PSNR.

    For every batch yielded by ``dataset``, ``sample[0]`` is passed to
    ``model`` and also serves as the reference ("real") images; element
    ``[0]`` of the model's return value is taken as the generated images.

    Args:
        dataset: Iterable of batches whose first element holds the images.
            NOTE(review): images are assumed to be floats in [-1, 1] --
            confirm against the data pipeline.
        model: Callable invoked as ``model(images, training=False)``.
        total_length: Number of batches; used only for the progress bar.

    Returns:
        dict with keys ``"fid"``, ``"ssim"`` and ``"psnr"``. SSIM and PSNR
        are means over all evaluated images.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    fake_embeddings = []
    real_embeddings = []
    psnrs = []
    ssims = []

    for sample in tqdm(dataset, total=total_length):
        real = sample[0]
        generated = model(real, training=False)[0]

        # InceptionV3 embeddings are computed on 299x299 inputs.
        real_resized = resize_images(real, (299, 299))
        generated_resized = resize_images(generated, (299, 299))

        real_activations = inceptionv3(real_resized, training=False)
        generated_activations = inceptionv3(generated_resized, training=False)
        real_embeddings.append(real_activations.numpy())
        fake_embeddings.append(generated_activations.numpy())

        fake_scaled = _images_to_uint8(generated)
        real_scaled = _images_to_uint8(real)

        # Keep per-image values as 1-D arrays so batches of different sizes
        # (e.g. a shorter final batch) can be concatenated later.
        psnrs.append(
            np.atleast_1d(tf.image.psnr(fake_scaled, real_scaled, 255).numpy())
        )
        ssims.append(
            np.atleast_1d(tf.image.ssim(fake_scaled, real_scaled, 255).numpy())
        )

    if not real_embeddings:
        raise ValueError("dataset yielded no batches; cannot compute metrics")

    # Arguments follow calculate_fid's (real, generated) parameter order.
    fid = calculate_fid(
        np.concatenate(real_embeddings, axis=0),
        np.concatenate(fake_embeddings, axis=0),
    )

    psnr = np.concatenate(psnrs).mean()
    ssim = np.concatenate(ssims).mean()
    return {"fid": fid, "ssim": ssim, "psnr": psnr}
|
|