ic_gan / inference /test.py
ArantxaCasanova
First model version
a00ee36
raw
history blame
No virus
11.4 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
import os
import numpy as np
from tqdm import trange
import pickle
import json
import sys
print(sys.path[0])
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import data_utils.inception_utils as inception_utils
import data_utils.utils as data_utils
import inference.utils as inference_utils
import BigGAN_PyTorch.utils as biggan_utils
# submitit is an optional dependency: this file only uses it in the
# __main__ section to submit the evaluation as a SLURM job. LOCAL records
# whether the import failed.
LOCAL = False
try:
    import submitit
except ImportError:
    # Only a missing/broken package is handled here; a bare `except:` would
    # also hide KeyboardInterrupt/SystemExit and unrelated errors.
    print(
        "No submitit package found! Defaulting to executing the script in the local machine"
    )
    LOCAL = True
class Tester:
    """Callable evaluation job for a trained generator.

    Calling an instance loads the generator, computes the Inception-based
    metrics (IS, FID and, when configured, stratified FID and PRDC), stores
    them as a ``.npy`` file under ``<samples_root>/<experiment_name>`` and,
    when ``config["sample_npz"]`` is set, additionally samples images and
    pickles them in the same folder.

    The instance is a plain callable with no arguments so that it can either
    be invoked directly or handed to a job executor.
    """

    def __init__(self, config):
        """Store the configuration.

        Args:
            config: either a ``dict`` (kept as-is, not copied) or any object
                whose attributes can be read with ``vars()`` (e.g. an
                ``argparse.Namespace``).
        """
        self.config = vars(config) if not isinstance(config, dict) else config

    def __call__(self) -> None:
        """Run the full evaluation; results are written to disk, not returned."""
        # Seed RNG
        biggan_utils.seed_rng(self.config["seed"])

        # torch is imported here rather than at module level, as in the
        # original code. NOTE(review): presumably so that it is only imported
        # inside the process that actually runs the job -- confirm.
        import torch

        # Setup cudnn.benchmark for free speed
        torch.backends.cudnn.benchmark = True

        self.config = biggan_utils.update_config_roots(
            self.config, change_weight_folder=False
        )
        # Prepare root folders if necessary
        biggan_utils.prepare_root(self.config)

        # Load model
        self.G, self.config = inference_utils.load_model_inference(self.config)
        biggan_utils.count_parameters(self.G)

        # Get sampling function and reference statistics for FID
        print("Eval reference set is ", self.config["eval_reference_set"])
        sample, im_reference_filename = inference_utils.get_sampling_funct(
            self.config,
            self.G,
            instance_set=self.config["eval_instance_set"],
            reference_set=self.config["eval_reference_set"],
            which_dataset=self.config["which_dataset"],
        )

        # Stratified FID is only requested for long-tail evaluation on "val".
        stratified_m = bool(
            self.config["eval_reference_set"] == "val" and self.config["longtail"]
        )

        if self.config["longtail"]:
            # NOTE: this path is relative to the current working directory.
            samples_per_class = np.load(
                "../BigGAN_PyTorch/imagenet_lt/imagenet_lt_samples_per_class.npy",
                allow_pickle=True,
            )
        else:
            samples_per_class = None

        # Prepare inception metrics function
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            im_reference_filename,
            samples_per_class,
            self.config["parallel"],
            self.config["no_fid"],
            self.config["data_root"],
            prdc=self.config["eval_prdc"],
            stratified_fid=stratified_m,
            backbone=self.config["model_backbone"],
        )

        prdc_loader = self._build_prdc_loader()

        # Get metrics
        use_all_images_for_pr = (
            self.config["longtail"] and self.config["eval_reference_set"] == "train"
        )
        eval_metrics = get_inception_metrics(
            sample,
            num_inception_images=self.config["num_inception_images"],
            num_splits=10,
            prints=False,
            loader_ref=prdc_loader,
            num_pr_images=self.config["num_inception_images"]
            if use_all_images_for_pr
            else 10000,
        )

        eval_metrics_dict = dict()
        if self.config["eval_prdc"]:
            IS_mean, IS_std, FID, stratified_FID, prdc_metrics = eval_metrics
        else:
            IS_mean, IS_std, FID, stratified_FID = eval_metrics
        if stratified_m:
            eval_metrics_dict["stratified_FID"] = stratified_FID
        eval_metrics_dict["IS_mean"] = IS_mean
        eval_metrics_dict["IS_std"] = IS_std
        eval_metrics_dict["FID"] = FID
        print(eval_metrics_dict)
        if self.config["eval_prdc"]:
            # Keys already present in eval_metrics_dict take precedence.
            eval_metrics_dict = {**prdc_metrics, **eval_metrics_dict}

        self._save_metrics(eval_metrics_dict)

        print("Computed metrics:")
        for key, value in eval_metrics_dict.items():
            print(key, ": ", value)

        if self.config["sample_npz"]:
            self._sample_and_save(sample, samples_per_class)

    def _experiment_dir(self):
        """Return ``<samples_root>/<experiment_name>``, creating it if needed."""
        path = os.path.join(
            self.config["samples_root"], self.config["experiment_name"]
        )
        # makedirs(exist_ok=True) also creates a missing samples_root and does
        # not fail if the folder appears between a check and the creation.
        os.makedirs(path, exist_ok=True)
        return path

    def _kmeans_tag(self):
        """Return the ``_kmeans<N>`` filename fragment, or "" when N is -1 or lower."""
        if self.config["kmeans_subsampled"] > -1:
            return "_kmeans" + str(self.config["kmeans_subsampled"])
        return ""

    @staticmethod
    def _dump_pickle(payload, filename):
        """Pickle ``payload`` to ``filename`` (protocol 4), always closing the file."""
        with open(filename, "wb") as file_to_store:
            pickle.dump(payload, file_to_store, protocol=4)

    def _build_prdc_loader(self):
        """Return the reference-set data loader used for PRDC, or None if PRDC is off."""
        # If computing PRDC, we need a loader to obtain reference Inception features
        if not self.config["eval_prdc"]:
            return None
        prdc_ref_set = data_utils.get_dataset_hdf5(
            **{
                **self.config,
                "data_path": self.config["data_root"],
                "load_in_mem_feats": self.config["load_in_mem"],
                "kmeans_subsampled": -1,
                "test_part": bool(
                    self.config["which_dataset"] == "coco"
                    and self.config["eval_reference_set"] == "val"
                ),
                "split": self.config["eval_reference_set"],
                "ddp": False,
            }
        )
        return data_utils.get_dataloader(
            **{
                **self.config,
                "dataset": prdc_ref_set,
                "batch_size": self.config["batch_size"],
                "use_checkpointable_sampler": False,
                "shuffle": True,
                "drop_last": False,
            }
        )

    def _save_metrics(self, eval_metrics_dict):
        """Store the metrics dict as a ``.npy`` file in the experiment folder."""
        out_dir = self._experiment_dir()

        add_suffix = ""
        if self.config["z_var"] != 1.0:
            add_suffix = "_z_var" + str(self.config["z_var"])

        # Save metrics in file
        print("Saving metrics in ", out_dir)
        # np.save stores a dict as a pickled object array, so reading it back
        # requires np.load(..., allow_pickle=True).
        np.save(
            os.path.join(
                out_dir,
                "eval_metrics_reference_"
                + self.config["eval_reference_set"]
                + "_instances_"
                + self.config["eval_instance_set"]
                + "_kmeans"
                + str(self.config["kmeans_subsampled"])
                + "_seed"
                + str(self.config["seed"])
                + add_suffix
                + ".npy",
            ),
            eval_metrics_dict,
        )

    def _sample_and_save(self, sample, samples_per_class):
        """Draw ``sample_num_npz`` images with ``sample`` and pickle them.

        Args:
            sample: zero-argument callable returning ``(images, labels, _)``
                as torch tensors.
            samples_per_class: array indexed by class label, used to split
                the samples into long-tail strata; only read when
                ``config["longtail"]`` is set and the reference set is "val".
        """
        import torch

        # Sample a number of images and save them to an NPZ, for use with TF-Inception
        samples_path = self._experiment_dir()
        num_samples = self.config["sample_num_npz"]
        print("Sampling %d images and saving them to npz..." % num_samples)

        # Lists to hold images and labels for images
        x, y = [], []
        num_batches = int(np.ceil(num_samples / float(self.config["batch_size"])))
        for _batch in trange(num_batches):
            with torch.no_grad():
                images, labels, _unused = sample()
                if self.config["model_backbone"] == "stylegan2":
                    # Map to [0, 255], clamp, then map back to [-1, 1].
                    images = torch.clamp((images * 127.5 + 128), 0, 255)
                    images = ((images / 255) - 0.5) * 2
            x += [images.cpu().numpy()]
            if self.config["class_cond"]:
                y += [labels.cpu().numpy()]

        # NOTE(review): the available source had lost its indentation; the
        # saving logic below is assumed to belong to the ImageNet-only branch,
        # so nothing is written for other datasets -- confirm against upstream.
        if self.config["which_dataset"] != "imagenet":
            return

        x = np.concatenate(x, 0)[:num_samples]
        if self.config["class_cond"]:
            y = np.concatenate(y, 0)[:num_samples]

        kmeans_tag = self._kmeans_tag()
        np_filename = "%s/samples%s_seed%i.pickle" % (
            samples_path,
            kmeans_tag,
            self.config["seed"],
        )
        print("Saving npy to %s..." % np_filename)
        self._dump_pickle({"x": x, "y": y}, np_filename)

        if not (
            self.config["longtail"] and self.config["eval_reference_set"] == "val"
        ):
            return

        print("Additionally storing stratified samples")
        # Number of training samples of each generated image's class.
        class_counts = samples_per_class[y]
        strata_masks = {
            "_many": class_counts >= 100,
            "_low": (class_counts < 100) & (class_counts > 20),
            "_few": class_counts <= 20,
        }
        for strat_name, mask in strata_masks.items():
            np_filename = "%s/%s/samples%s_seed%i_strat%s.pickle" % (
                self.config["samples_root"],
                self.config["experiment_name"],
                kmeans_tag,
                self.config["seed"],
                strat_name,
            )
            print(np_filename)
            self._dump_pickle({"x": x[mask], "y": y[mask]}, np_filename)
if __name__ == "__main__":
    # Build the command-line configuration from the project's parsers.
    parser = biggan_utils.prepare_parser()
    parser = biggan_utils.add_sample_parser(parser)
    parser = inference_utils.add_backbone_parser(parser)
    config = vars(parser.parse_args())
    config["n_classes"] = 1000

    # Values from the optional JSON file override the command-line ones.
    if config["json_config"] != "":
        # Context manager so the file handle is closed deterministically.
        with open(config["json_config"]) as json_file:
            data = json.load(json_file)
        for key, value in data.items():
            if "exp_name" in key:
                # Any key containing "exp_name" is stored as "experiment_name".
                config["experiment_name"] = value
            else:
                config[key] = value
    else:
        print("No json file to load configuration from")

    tester = Tester(config)
    # Run in-process when asked to, and also when submitit could not be
    # imported (LOCAL): in that case the SLURM branch below would fail with a
    # NameError, and the import guard has already announced local execution.
    if config["run_setup"] == "local_debug" or LOCAL:
        tester()
    else:
        executor = submitit.SlurmExecutor(
            folder=config["slurm_logdir"], max_num_timeout=10
        )
        executor.update_parameters(
            gpus_per_node=1,
            partition=config["partition"],
            cpus_per_task=8,
            mem=128000,
            time=30,
            job_name="testing_" + config["experiment_name"],
        )
        executor.submit(tester)
        import time

        # NOTE(review): presumably a short grace period after submission
        # before the script exits -- confirm whether it is still needed.
        time.sleep(1)