# ic_gan/data_utils/inception_utils.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
""" Inception utilities
This file contains methods for calculating IS and FID, using either
the original numpy code or an accelerated, fully-PyTorch version that
uses a fast Newton-Schulz approximation for the matrix sqrt. There are also
methods for acquiring a desired number of samples from the Generator,
and for parallelizing the built-in PyTorch Inception network.
NOTE that Inception Scores and FIDs calculated using these methods will
*not* be directly comparable to values calculated using the original TF
IS/FID code. You *must* use the TF model if you wish to report and compare
numbers. This code tends to produce IS values that are 5-10% lower than
those obtained through TF.
"""
import os
import numpy as np
from scipy import linalg # For numpy FID
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter as P
from torchvision.models.inception import inception_v3
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
from data_utils.compute_pdrc import compute_prdc
# Module that wraps the Inception network to enable use with DataParallel and
# to return both pool features and logits.
class WrapInception(nn.Module):
def __init__(self, net):
super(WrapInception, self).__init__()
self.net = net
self.mean = P(
torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1), requires_grad=False
)
self.std = P(
torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1), requires_grad=False
)
def forward(self, x):
# Normalize x
x = (x + 1.0) / 2.0
x = (x - self.mean) / self.std
# Upsample if necessary
if x.shape[2] != 299 or x.shape[3] != 299:
x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True)
# 299 x 299 x 3
x = self.net.Conv2d_1a_3x3(x)
# 149 x 149 x 32
x = self.net.Conv2d_2a_3x3(x)
# 147 x 147 x 32
x = self.net.Conv2d_2b_3x3(x)
# 147 x 147 x 64
x = F.max_pool2d(x, kernel_size=3, stride=2)
# 73 x 73 x 64
x = self.net.Conv2d_3b_1x1(x)
# 73 x 73 x 80
x = self.net.Conv2d_4a_3x3(x)
# 71 x 71 x 192
x = F.max_pool2d(x, kernel_size=3, stride=2)
# 35 x 35 x 192
x = self.net.Mixed_5b(x)
# 35 x 35 x 256
x = self.net.Mixed_5c(x)
# 35 x 35 x 288
x = self.net.Mixed_5d(x)
# 35 x 35 x 288
x = self.net.Mixed_6a(x)
# 17 x 17 x 768
x = self.net.Mixed_6b(x)
# 17 x 17 x 768
x = self.net.Mixed_6c(x)
# 17 x 17 x 768
x = self.net.Mixed_6d(x)
# 17 x 17 x 768
x = self.net.Mixed_6e(x)
# 17 x 17 x 768
# 17 x 17 x 768
x = self.net.Mixed_7a(x)
# 8 x 8 x 1280
x = self.net.Mixed_7b(x)
# 8 x 8 x 2048
x = self.net.Mixed_7c(x)
# 8 x 8 x 2048
pool = torch.mean(x.view(x.size(0), x.size(1), -1), 2)
# 1 x 1 x 2048
logits = self.net.fc(F.dropout(pool, training=False).view(pool.size(0), -1))
# 1000 (num_classes)
return pool, logits
# A pytorch implementation of cov, from Modar M. Alfadly
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def torch_cov(m, rowvar=False):
"""Estimate a covariance matrix given data.
Covariance indicates the level to which two variables vary together.
If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
then the covariance matrix element `C_{ij}` is the covariance of
`x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
Parameters
----------
m: A 1-D or 2-D array containing multiple variables and observations.
Each row of `m` represents a variable, and each column a single
observation of all those variables.
rowvar: If `rowvar` is True, then each row represents a
variable, with observations in the columns. Otherwise, the
relationship is transposed: each column represents a variable,
while the rows contain observations.
Returns
-------
The covariance matrix of the variables.
"""
if m.dim() > 2:
raise ValueError("m has more than 2 dimensions")
if m.dim() < 2:
m = m.view(1, -1)
if not rowvar and m.size(0) != 1:
m = m.t()
# m = m.type(torch.double) # uncomment this line if desired
fact = 1.0 / (m.size(1) - 1)
    # NOTE: this subtracts the mean from `m` in place, so callers that need to
    # keep their tensor intact should pass a clone (see compute_fid below).
    m -= torch.mean(m, dim=1, keepdim=True)
mt = m.t() # if complex: mt = m.t().conj()
return fact * m.matmul(mt).squeeze()
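
# Hedged sanity check (illustrative sketch, not part of the original code):
# torch_cov should agree with np.cov on the same data when rowvar=False. A
# clone is passed because torch_cov subtracts the mean from its input in place.
def _check_torch_cov(num_samples=128, dim=8):
    x = torch.randn(num_samples, dim)
    cov_torch = torch_cov(x.clone(), rowvar=False).numpy()
    cov_numpy = np.cov(x.numpy(), rowvar=False)
    assert np.allclose(cov_torch, cov_numpy, atol=1e-5)
    return cov_torch
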
# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
# https://github.com/msubhransu/matrix-sqrt
def sqrt_newton_schulz(A, numIters, dtype=None):
with torch.no_grad():
if dtype is None:
dtype = A.type()
batchSize = A.shape[0]
dim = A.shape[1]
normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
Y = A.div(normA.view(batchSize, 1, 1).expand_as(A))
I = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
Z = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
for i in range(numIters):
T = 0.5 * (3.0 * I - Z.bmm(Y))
Y = Y.bmm(T)
Z = T.bmm(Z)
sA = Y * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
return sA
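
# Hedged sanity check (illustrative sketch): for a well-conditioned symmetric
# positive-definite input, the Newton-Schulz iterate should satisfy
# sqrt(A) @ sqrt(A) ~= A up to a small relative error.
def _check_sqrt_newton_schulz(dim=16, num_iters=50):
    m = torch.randn(dim, dim)
    a = m.mm(m.t()) + dim * torch.eye(dim)  # symmetric positive-definite
    s = sqrt_newton_schulz(a.unsqueeze(0), num_iters).squeeze(0)
    rel_err = (s.mm(s) - a).norm() / a.norm()
    assert rel_err < 1e-2
    return rel_err
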
# FID calculator from TTUR--consider replacing this with GPU-accelerated cov
# calculations using torch?
def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Numpy implementation of the Frechet Distance.
Taken from https://github.com/bioinf-jku/TTUR
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Parameters
----------
    mu1 : The sample mean over activations of a layer of the Inception net
          (as returned by accumulate_inception_activations) for generated
          samples.
    mu2 : The sample mean over activations, precalculated on a
          representative data set.
    sigma1: The covariance matrix over activations for generated samples.
    sigma2: The covariance matrix over activations, precalculated on a
            representative data set.
Returns
-------
The Frechet Distance (float).
"""
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
diff = mu1 - mu2
# Product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = (
"fid calculation produces singular product; "
"adding %s to diagonal of cov estimates"
) % eps
print(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
print("wat")
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
return out
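
# Hedged worked example (illustrative sketch): for two Gaussians with the same
# covariance the trace term cancels, so the Frechet distance reduces to the
# squared distance between the means and FID(N(0, I), N(1, I)) = dim.
def _check_numpy_fid(dim=4):
    mu1, mu2 = np.zeros(dim), np.ones(dim)
    sigma = np.eye(dim)
    fid = numpy_calculate_frechet_distance(mu1, sigma, mu2, sigma)
    assert np.isclose(fid, float(dim), atol=1e-4)
    return fid
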
def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Pytorch implementation of the Frechet Distance.
Taken from https://github.com/bioinf-jku/TTUR
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Parameters
----------
    mu1 : The sample mean over activations of a layer of the Inception net
          (as returned by accumulate_inception_activations) for generated
          samples, as a torch tensor.
    mu2 : The sample mean over activations, precalculated on a
          representative data set.
    sigma1: The covariance matrix over activations for generated samples.
    sigma2: The covariance matrix over activations, precalculated on a
            representative data set.
Returns
-------
The Frechet Distance (float).
"""
assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
diff = mu1 - mu2
# Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2
covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze()
out = (
diff.dot(diff)
+ torch.trace(sigma1)
+ torch.trace(sigma2)
- 2 * torch.trace(covmean)
)
return out
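
# Hedged cross-check (illustrative sketch): the PyTorch and numpy FID
# implementations should agree closely on the same moderately sized inputs,
# with only a small discrepancy left by the approximate matrix sqrt.
def _check_torch_vs_numpy_fid(num_samples=256, dim=32):
    feats1 = np.random.randn(num_samples, dim)
    feats2 = np.random.randn(num_samples, dim) + 0.5
    mu1, sigma1 = np.mean(feats1, 0), np.cov(feats1, rowvar=False)
    mu2, sigma2 = np.mean(feats2, 0), np.cov(feats2, rowvar=False)
    fid_np = numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    fid_torch = torch_calculate_frechet_distance(
        torch.tensor(mu1).float(),
        torch.tensor(sigma1).float(),
        torch.tensor(mu2).float(),
        torch.tensor(sigma2).float(),
    ).item()
    assert abs(fid_np - fid_torch) / max(fid_np, 1e-8) < 0.05
    return fid_np, fid_torch
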
# Calculate Inception Score mean + std given softmax'd logits and number of splits
def calculate_inception_score(pred, num_splits=10):
scores = []
for index in range(num_splits):
pred_chunk = pred[
index
* (pred.shape[0] // num_splits) : (index + 1)
* (pred.shape[0] // num_splits),
:,
]
kl_inception = pred_chunk * (
np.log(pred_chunk) - np.log(np.expand_dims(np.mean(pred_chunk, 0), 0))
)
kl_inception = np.mean(np.sum(kl_inception, 1))
scores.append(np.exp(kl_inception))
return np.mean(scores), np.std(scores)
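
# Hedged worked example (illustrative sketch): confident predictions spread
# evenly over the classes maximize the Inception Score (approaching the number
# of classes), while uniform predictions give a score of 1. Exact zeros are
# avoided because the score takes the log of the predictions.
def _check_inception_score(num_samples=1000, num_classes=10, eps=1e-6):
    confident = np.full((num_samples, num_classes), eps)
    confident[np.arange(num_samples), np.arange(num_samples) % num_classes] = (
        1.0 - (num_classes - 1) * eps
    )
    uniform = np.full((num_samples, num_classes), 1.0 / num_classes)
    is_confident, _ = calculate_inception_score(confident, num_splits=10)
    is_uniform, _ = calculate_inception_score(uniform, num_splits=10)
    assert np.isclose(is_confident, num_classes, atol=0.05)
    assert np.isclose(is_uniform, 1.0, atol=1e-3)
    return is_confident, is_uniform
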
# Loop over the sampler and the net until num_inception_images activations have
# been accumulated. Returns the pool features, the logits, and the labels (the
# labels of the generated classes are needed if one wants Inception Accuracy).
def accumulate_inception_activations(sample, net, num_inception_images=50000, model_backbone='biggan'):
pool, logits, labels = [], [], []
while (torch.cat(logits, 0).shape[0] if len(logits) else 0) < num_inception_images:
with torch.no_grad():
images, labels_val, _ = sample()
if model_backbone == 'stylegan2':
images = torch.clamp((images * 127.5 + 128), 0, 255)
images = ((images / 255) - 0.5) * 2
if labels_val is not None:
labels_val = labels_val.long()
pool_val, logits_val = net(images.float())
pool += [pool_val]
logits += [F.softmax(logits_val, 1)]
labels += [labels_val]
return (
torch.cat(pool, 0),
torch.cat(logits, 0),
torch.cat(labels, 0)
if labels[0] is not None
else torch.zeros(torch.cat(logits, 0).shape[0]).long(),
)
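
# Hedged sketch (assumption, not part of the original code): `sample` above is
# expected to be a zero-argument closure returning (images, labels, extra),
# with images scaled to [-1, 1] for the default 'biggan' backbone and labels
# either class indices or None. A dummy sampler for smoke-testing the
# accumulation loop could look like this; the 64x64 resolution is arbitrary.
def _dummy_sampler(batch_size=8, num_classes=1000, device="cpu"):
    def sample():
        images = torch.rand(batch_size, 3, 64, 64, device=device) * 2 - 1
        labels = torch.randint(0, num_classes, (batch_size,), device=device)
        return images, labels, None
    return sample
# Example: pool, logits, labels = accumulate_inception_activations(
#     _dummy_sampler(), load_inception_net(device="cpu"), num_inception_images=16)
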
### Iterates over the real data in the loader to return the labels and pool activations.
def accumulate_features(net, loader, num_inception_images=50000, device="cuda"):
pool_real = []
for i, batch in enumerate(loader):
x = batch[0]
with torch.no_grad():
x = x.to(device).float()
pool_real += [net(x)[0].cpu()]
if (
torch.cat(pool_real, 0).shape[0] if len(pool_real) else 0
) >= num_inception_images:
break
return torch.cat(pool_real, 0).cpu()[:num_inception_images]
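
# Hedged sketch (assumption, not part of the original code): `loader` can be
# any iterable of batches whose first element is an image tensor in [-1, 1],
# e.g. a DataLoader over a TensorDataset. A minimal smoke test (requires the
# pretrained Inception weights) could look like this.
def _check_accumulate_features(num_images=32, device="cpu"):
    from torch.utils.data import DataLoader, TensorDataset
    fake_images = torch.rand(num_images, 3, 64, 64) * 2 - 1
    fake_labels = torch.zeros(num_images).long()
    loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=8)
    feats = accumulate_features(
        load_inception_net(device=device), loader, num_inception_images=16, device=device
    )
    assert feats.shape == (16, 2048)
    return feats
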
# Load and wrap the Inception model
def load_inception_net(parallel=False, device="cuda"):
inception_model = inception_v3(pretrained=True, transform_input=False)
inception_model = WrapInception(inception_model.eval()).to(device)
if parallel:
print("Parallelizing Inception module...")
inception_model = nn.DataParallel(inception_model)
return inception_model
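
# Hedged quick check (illustrative sketch; torchvision downloads the pretrained
# Inception weights on first use): the wrapped network maps inputs in [-1, 1]
# to 2048-d pool features and 1000-way logits, resizing to 299x299 internally.
def _check_wrapped_inception(device="cpu"):
    net = load_inception_net(parallel=False, device=device)
    x = torch.rand(2, 3, 128, 128, device=device) * 2 - 1
    with torch.no_grad():
        pool, logits = net(x)
    assert pool.shape == (2, 2048) and logits.shape == (2, 1000)
    return pool, logits
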
# This produces a function which takes in an iterator that returns a set number
# of samples and iterates until it accumulates config['num_inception_images']
# images. The iterator can return samples with a different batch size than the
# one used in training, via the setting config['inception_batchsize'].
def prepare_inception_metrics(
dataset,
samples_per_class,
parallel,
no_fid=False,
data_root="",
split_name="",
stratified_fid=False,
prdc=False,
device="cuda",
backbone='biggan',
):
    # Load the pre-computed dataset moments; this is intentionally not wrapped
    # in a try-except block so that the script will crash here if it cannot
    # find the Inception moments.
print(
"Loading dataset inception moments from ",
os.path.join(
data_root, dataset + "_" + "inception_moments" + split_name + ".npz"
),
)
stats = np.load(
os.path.join(
data_root, dataset + "_" + "inception_moments" + split_name + ".npz"
)
)
data_mu = stats["mu"]
data_sigma = stats["sigma"]
if stratified_fid:
many_stats = np.load(
os.path.join(data_root, dataset + "_many_inception_moments.npz")
)
low_stats = np.load(
os.path.join(data_root, dataset + "_low_inception_moments.npz")
)
few_stats = np.load(
os.path.join(data_root, dataset + "_few_inception_moments.npz")
)
# Load network
net = load_inception_net(parallel, device=device)
def get_inception_metrics(
sample,
num_inception_images,
num_splits=10,
prints=True,
use_torch=True,
loader_ref=None,
num_pr_images=10000
):
if prints:
print("Gathering activations...")
pool, logits, labels_val = accumulate_inception_activations(
sample, net, num_inception_images, backbone
)
# Obtain features for the real ground-truth data
if prdc and loader_ref is not None:
pool_real = accumulate_features(
net, loader_ref, num_inception_images, device=device
)
print("Subsampling %i samples for prdc metrics!" % (num_pr_images))
idxs_selected = np.random.choice(
range(len(pool_real)), num_pr_images, replace=False
)
prdc_metrics = compute_prdc(
pool_real[idxs_selected], pool[idxs_selected].cpu(), 5
)
if prints:
print("Calculating Inception Score...")
IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits)
if no_fid:
FID = 9999.0
else:
if prints:
print("Calculating means and covariances...")
FID = compute_fid(
pool.clone(), data_mu, data_sigma, prints, use_torch, device=device
)
        # Obtain stratified FID metrics for the ImageNet-LT dataset.
stratified_fid_list = []
if stratified_fid:
labels_val = labels_val.cpu()
pool = pool.cpu()
for stats, strat_name in zip(
[many_stats, low_stats, few_stats], ["_many", "_low", "_few"]
):
if strat_name == "_many":
pool_ = pool[samples_per_class[labels_val] >= 100]
print("For many-shot, selecting ", len(pool_), " samples.")
elif strat_name == "_low":
pool_ = pool[samples_per_class[labels_val] < 100]
labels_ = labels_val[samples_per_class[labels_val] < 100]
pool_ = pool_[samples_per_class[labels_] > 20]
print("For low-shot, selecting ", len(pool_), " samples.")
elif strat_name == "_few":
pool_ = pool[samples_per_class[labels_val] <= 20]
print("For few-shot, selecting ", len(pool_), " samples.")
stratified_fid_list.append(
compute_fid(pool_, stats["mu"], stats["sigma"], prints, False)
)
del pool_
        # Delete pool, logits, and labels, just in case
del pool, logits, labels_val
if prdc:
return IS_mean, IS_std, FID, stratified_fid_list, prdc_metrics
else:
return IS_mean, IS_std, FID, stratified_fid_list
return get_inception_metrics
def compute_fid(pool, data_mu, data_sigma, prints, use_torch, device="cuda"):
if use_torch:
mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False)
else:
mu, sigma = (
np.mean(pool.cpu().numpy(), axis=0),
np.cov(pool.cpu().numpy(), rowvar=False),
)
if prints:
print("Covariances calculated, getting FID...")
if use_torch:
FID = torch_calculate_frechet_distance(
mu,
sigma,
torch.tensor(data_mu).float().to(device),
torch.tensor(data_sigma).float().to(device),
)
FID = float(FID.cpu().numpy())
else:
FID = numpy_calculate_frechet_distance(mu, sigma, data_mu, data_sigma)
del mu, sigma
return FID
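
# Hedged end-to-end usage sketch (assumptions: `sample` follows the closure
# contract sketched next to accumulate_inception_activations, and a moment file
# named "<dataset>_inception_moments<split_name>.npz" with "mu" and "sigma"
# entries exists under `data_root`; the dataset name "imagenet" is only an
# example).
def _example_metrics_run(sample, data_root, dataset="imagenet", device="cuda"):
    get_metrics = prepare_inception_metrics(
        dataset=dataset,
        samples_per_class=None,  # only needed when stratified_fid=True
        parallel=False,
        no_fid=False,
        data_root=data_root,
        split_name="",
        stratified_fid=False,
        prdc=False,
        device=device,
        backbone="biggan",
    )
    is_mean, is_std, fid, _ = get_metrics(sample, num_inception_images=10000)
    print("IS: %.2f +/- %.2f, FID: %.2f" % (is_mean, is_std, fid))
    return is_mean, is_std, fid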