ic_gan / inference /sample.py
ArantxaCasanova
First model version
a00ee36
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
""" BigGAN: The Authorized Unofficial PyTorch release
Code by A. Brock and A. Andonian
This code is an unofficial reimplementation of
"Large-Scale GAN Training for High Fidelity Natural Image Synthesis,"
by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096).
Let's go.
"""
import os
import numpy as np
from tqdm import tqdm, trange
import json
from imageio import imwrite as imsave
# Import my stuff
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import inference.utils as inference_utils
import BigGAN_PyTorch.utils as biggan_utils
class Tester:
    """Sample images from a pretrained generator and save them to disk.

    All settings are read from ``self.config``. Calling the instance loads the
    model, draws ``sample_num_npz`` images in batches of ``batch_size`` and
    writes each one as an individual image file under the folder returned by
    ``_samples_dir``.
    """

    def __init__(self, config):
        """Store the configuration.

        Args:
            config: a dict, or an object accepted by ``vars()`` (e.g. an
                ``argparse.Namespace``), which is converted to a dict.
        """
        self.config = vars(config) if not isinstance(config, dict) else config

    @staticmethod
    def _samples_dir(config):
        """Return the folder where sampled images are written.

        The folder is ``<samples_root>/<experiment_name>/`` followed by a name
        encoding the dataset and seed, plus ``_test`` when sampling the COCO
        ``val`` instance set, ``_hd<filter_hd>`` when ``filter_hd > -1`` and
        ``_<k>centers`` when ``kmeans_subsampled != -1``.
        """
        test_part = (
            config["eval_instance_set"] == "val"
            and config["which_dataset"] == "coco"
        )
        hd_suffix = (
            "_hd" + str(config["filter_hd"]) if config["filter_hd"] > -1 else ""
        )
        kmeans_suffix = (
            ""
            if config["kmeans_subsampled"] == -1
            else "_" + str(config["kmeans_subsampled"]) + "centers"
        )
        return os.path.join(
            config["samples_root"],
            config["experiment_name"],
            "%s_images_seed%i%s%s%s"
            % (
                config["which_dataset"],
                config["seed"],
                "_test" if test_part else "",
                hd_suffix,
                kmeans_suffix,
            ),
        )

    @staticmethod
    def _postprocess(fake_imgs, model_backbone):
        """Rescale a channels-last numpy batch of generator outputs for saving.

        - ``"biggan"``: ``x * 0.5 + 0.5`` (stays float; the conversion to
          8-bit is left to ``imsave``).
        - ``"stylegan2"``: ``x * 127.5 + 128`` clipped to [0, 255] and cast to
          ``uint8``.
        - any other backbone: returned unchanged.

        NOTE(review): both mappings presumably assume outputs in [-1, 1]
        (e.g. a tanh head) — confirm against the generators.
        """
        if model_backbone == "biggan":
            return fake_imgs * 0.5 + 0.5
        if model_backbone == "stylegan2":
            return np.clip((fake_imgs * 127.5 + 128), 0, 255).astype(np.uint8)
        return fake_imgs

    def __call__(self) -> None:
        """Load the generator, draw ``sample_num_npz`` samples and save them.

        Images are written as ``%06d.<fmt>`` files, where ``fmt`` is ``jpg``
        for COCO and ``png`` otherwise. Nothing is returned.
        """
        # Seed RNG
        biggan_utils.seed_rng(self.config["seed"])
        import torch

        # Setup cudnn.benchmark for free speed
        torch.backends.cudnn.benchmark = True
        self.config = biggan_utils.update_config_roots(
            self.config, change_weight_folder=False
        )
        # Prepare root folders if necessary
        biggan_utils.prepare_root(self.config)
        # Loading the model also returns a config; every lookup below goes
        # through self.config (never a module-level global) so that the
        # returned config is the one that is actually used.
        self.G, self.config = inference_utils.load_model_inference(self.config)
        biggan_utils.count_parameters(self.G)

        # Get sampling function (the reference-statistics filename it also
        # returns is not needed for plain sampling).
        print("Eval reference set is ", self.config["eval_reference_set"])
        sample, _ = inference_utils.get_sampling_funct(
            self.config,
            self.G,
            instance_set=self.config["eval_instance_set"],
            reference_set=self.config["eval_reference_set"],
            which_dataset=self.config["which_dataset"],
        )

        image_format = "jpg" if self.config["which_dataset"] == "coco" else "png"
        path_samples = self._samples_dir(self.config)
        print("Path samples will be ", path_samples)
        # makedirs also creates <samples_root>/<experiment_name> when missing.
        os.makedirs(path_samples, exist_ok=True)

        num_samples = self.config["sample_num_npz"]
        print(
            "Sampling %d images and saving them with %s format..."
            % (num_samples, image_format)
        )
        num_batches = int(np.ceil(num_samples / float(self.config["batch_size"])))
        counter_i = 0
        for _ in trange(num_batches):
            with torch.no_grad():
                images, _, _ = sample()
                # NCHW tensor -> NHWC numpy array for the image writer.
                fake_imgs = images.cpu().detach().numpy().transpose(0, 2, 3, 1)
            fake_imgs = self._postprocess(fake_imgs, self.config["model_backbone"])
            for fake_img in fake_imgs:
                imsave(
                    "%s/%06d.%s" % (path_samples, counter_i, image_format), fake_img
                )
                counter_i += 1
                # The last batch may overshoot; stop at exactly num_samples.
                if counter_i >= num_samples:
                    break
if __name__ == "__main__":
parser = biggan_utils.prepare_parser()
parser = biggan_utils.add_sample_parser(parser)
parser = inference_utils.add_backbone_parser(parser)
config = vars(parser.parse_args())
config["n_classes"] = 1000
if config["json_config"] != "":
data = json.load(open(config["json_config"]))
for key in data.keys():
if "exp_name" in key:
config["experiment_name"] = data[key]
else:
config[key] = data[key]
else:
print("No json file to load configuration from")
tester = Tester(config)
tester()