|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import sys |
|
|
|
sys.path.insert(1, os.path.join(sys.path[0], "..")) |
|
|
|
import numpy as np |
|
import functools |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
import stylegan2_ada_pytorch.dnnlib as dnnlib |
|
import stylegan2_ada_pytorch.legacy as legacy |
|
import BigGAN_PyTorch.utils as biggan_utils |
|
import BigGAN_PyTorch.BigGAN as BigGANModel |
|
import data_utils.utils as data_utils |
|
|
|
|
|
def get_sampling_funct(
    config,
    generator,
    instance_set="train",
    reference_set="train",
    which_dataset="imagenet",
):
    """Prepare the generator sampling function and the Inception-moments filename.

    Arguments
    ---------
    config: dict
        Dictionary with configuration parameters.
    generator: torch.nn.module
        Generator network.
    instance_set: str, optional
        Split used to build the instance-conditioning dataset (only built when
        ``config["instance_cond"]`` is True). If `train`, use the training split;
        if `val`, use the validation split.
    reference_set: str, optional
        If `train`, use training as a reference to compute metrics.
        If `val`, use validation as a reference to compute metrics.
        It also selects the long-tail sampling behaviour and the moments filename.
    which_dataset: str, optional
        Dataset name; used to choose the Inception-moments filename prefix.

    Returns
    -------
    sample_: function
        Function that can be called with no arguments to sample a batch of
        generated images (a partial application of `sample`).
    im_filename: str
        Filename prefix where to find the inception moments used to compute the
        FID metric (with Pytorch code).
    """
    if config["longtail"]:
        # NOTE(review): these paths are relative to the current working
        # directory, so the caller must run from a sibling of BigGAN_PyTorch.
        class_probabilities = np.load(
            "../BigGAN_PyTorch/imagenet_lt/imagenet_lt_class_prob.npy",
            allow_pickle=True,
        )
        samples_per_class = np.load(
            "../BigGAN_PyTorch/imagenet_lt/imagenet_lt_samples_per_class.npy",
            allow_pickle=True,
        )
    else:
        class_probabilities, samples_per_class = None, None

    # The COCO "test" partition is used only when both instances and the
    # reference come from the validation split.
    # NOTE(review): this reads config["which_dataset"], whereas the filename
    # prefix below uses the `which_dataset` argument; confirm callers keep them equal.
    test_part = (
        reference_set == "val"
        and instance_set == "val"
        and config["which_dataset"] == "coco"
    )

    # Noise and label samplers. StyleGAN2 always uses a 512-d latent here.
    z_, y_ = data_utils.prepare_z_y(
        config["batch_size"],
        generator.dim_z if config["model_backbone"] == "biggan" else 512,
        config["n_classes"],
        device="cuda",
        fp16=config["G_fp16"],
        longtail_gen=config["longtail"] if reference_set == "train" else False,
        z_var=config["z_var"],
        class_probabilities=class_probabilities,
    )

    # Instance features are drawn from a dataset only for instance-conditioned models.
    if config["instance_cond"]:
        dataset = data_utils.get_dataset_hdf5(
            **{
                **config,
                "data_path": config["data_root"],
                "batch_size": config["batch_size"],
                "load_in_mem_feats": config["load_in_mem"],
                "split": instance_set,
                "test_part": test_part,
                "augment": False,
                "ddp": False,
            }
        )
    else:
        dataset = None

    # Long-tailed class+instance models either sample classes uniformly (val
    # reference) or weight them by their training frequency (train reference).
    weights_sampling = None
    nn_sampling_strategy = "instance_balance"
    if config["instance_cond"] and config["class_cond"] and config["longtail"]:
        nn_sampling_strategy = "nnclass_balance"
        if reference_set == "val":
            print("Sampling classes uniformly for generator.")
        else:
            print("Balancing with weights=samples_per_class (long-tailed).")
            weights_sampling = samples_per_class

    sample_conditioning = functools.partial(
        data_utils.sample_conditioning_values,
        z_=z_,
        y_=y_,
        constant_conditioning=config["constant_conditioning"],
        batch_size=config["batch_size"],
        weights_sampling=weights_sampling,
        dataset=dataset,
        class_cond=config["class_cond"],
        instance_cond=config["instance_cond"],
        nn_sampling_strategy=nn_sampling_strategy,
    )

    sample_ = functools.partial(
        sample,
        generator,
        sample_conditioning_func=sample_conditioning,
        config=config,
        class_cond=config["class_cond"],
        instance_cond=config["instance_cond"],
        backbone=config["model_backbone"],
        truncation_value=config["z_var"],
    )

    # Build the Inception-moments filename prefix for the chosen reference set.
    im_prefix = "I" if which_dataset == "imagenet" else "COCO"
    if reference_set == "train":
        im_filename = "%s%i_%s" % (
            im_prefix,
            config["resolution"],
            "" if not config["longtail"] else "longtail",
        )
    else:
        im_filename = "%s%i_%s%s" % (
            im_prefix,
            config["resolution"],
            "_val",
            "_test" if test_part else "",
        )
    print("Using ", im_filename, "for Inception metrics.")
    return sample_, im_filename
|
|
|
|
|
def sample(
    generator,
    sample_conditioning_func,
    config,
    class_cond=True,
    instance_cond=False,
    device="cuda",
    backbone="biggan",
    truncation_value=1.0,
):
    """Sample generated images from the model, given the input noise (and conditioning).

    Arguments
    ---------
    generator: torch.nn.module
        Generator network.
    sample_conditioning_func: function
        A zero-argument function returning the conditionings to feed the generator:
        ``z`` alone (unconditional), ``(z, y)`` (class), ``(z, feats)`` (instance)
        or ``(z, y, feats)`` (class + instance).
    config: dict
        Dictionary with configuration parameters. Reads ``"parallel"`` for BigGAN
        and ``"n_classes"`` for class-conditioned StyleGAN2.
    class_cond: bool, optional
        If True, use class labels to condition the generator.
    instance_cond: bool, optional
        If True, use instance features to condition the generator.
    device: str, optional
        Device name.
    backbone: str, optional
        Name of the backbone architecture to use ("biggan" or "stylegan2").
    truncation_value: float, optional
        Variance for the noise distribution, attributed to the truncation values in BigGAN.
        Passed as ``truncation_psi`` to StyleGAN2.

    Returns
    -------
    gen_samples: torch.FloatTensor
        Generated images.
    y_: torch.Tensor
        Sampled class labels. If using BigGAN backbone, y_.shape = [bs] (integer labels),
        if using StyleGAN2 backbone, y_ is a float one-hot matrix of shape
        [bs, n_classes] (or zeros of shape [bs, c_dim] without class conditioning),
        where `bs` is the batch size.
    feats_: torch.Tensor
        Sampled instance feature vectors, with shape [bs, h_dim], where `bs` is the batch size
        and `h_dim` is the dimensionality of the instance feature vectors. For StyleGAN2
        without instance conditioning this is a zero tensor; for BigGAN it is None.

    Raises
    ------
    ValueError
        If `backbone` is neither "biggan" nor "stylegan2".
    """
    if backbone not in ("biggan", "stylegan2"):
        raise ValueError(
            "Unknown backbone %r; expected 'biggan' or 'stylegan2'." % (backbone,)
        )

    conditioning = sample_conditioning_func()

    with torch.no_grad():
        # Unpack according to which conditionings the sampler produces.
        y_, feats_ = None, None
        if class_cond and instance_cond:
            z_, y_, feats_ = conditioning
        elif class_cond:
            z_, y_ = conditioning
        elif instance_cond:
            z_, feats_ = conditioning
        else:
            z_ = conditioning

        z_ = z_.to(device, non_blocking=True)
        if y_ is not None:
            # Labels index an embedding (BigGAN) or an identity matrix
            # (StyleGAN2), so they must be integers in every branch.
            y_ = y_.long().to(device, non_blocking=True)
        if feats_ is not None:
            feats_ = feats_.to(device, non_blocking=True)

        if backbone == "stylegan2":
            batch_size = z_.shape[0]
            # StyleGAN2 always expects tensors for `c` and `feats`; use
            # deterministic zeros (not uninitialized memory) when absent.
            if y_ is None:
                y_ = torch.zeros([batch_size, generator.c_dim], device=device)
            else:
                y_ = torch.eye(config["n_classes"], device=device)[y_]
            if feats_ is None:
                feats_ = torch.zeros([batch_size, generator.h_dim], device=device)
            gen_samples = generator(
                z=z_,
                c=y_,
                feats=feats_,
                truncation_psi=truncation_value,
                noise_mode="const",
            )
        elif config["parallel"]:
            gen_samples = nn.parallel.data_parallel(generator, (z_, y_, feats_))
        else:
            gen_samples = generator(z_, y_, feats_)
    return gen_samples, y_, feats_
|
|
|
|
|
# Configuration keys that keep their run-time value instead of being
# overwritten by the configuration stored in a BigGAN checkpoint.
_INFERENCE_KEPT_CONFIG_KEYS = frozenset(
    [
        "base_root",
        "data_root",
        "load_weights",
        "batch_size",
        "num_workers",
        "weights_root",
        "logs_root",
        "samples_root",
        "eval_reference_set",
        "eval_instance_set",
        "which_dataset",
        "seed",
        "eval_prdc",
        "use_balanced_sampler",
        "custom_distrib",
        "longtail_temperature",
        "longtail_gen",
        "num_inception_images",
        "sample_num_npz",
        "load_in_mem",
        "split",
        "z_var",
        "kmeans_subsampled",
        "filter_hd",
        "n_subsampled_data",
        "feature_augmentation",
    ]
)


def _select_best_biggan_checkpoint(config):
    """Return the suffix ("best0"/"best1") of the stored state dict with lowest FID.

    Returns "" when neither state dict can be read.
    """
    root = "/".join([config["weights_root"], config["experiment_name"]])
    best_fid = 1e5
    best_name_final = ""
    for name_best in ["best0", "best1"]:
        path = "%s/%s.pth" % (
            root,
            biggan_utils.join_strings("_", ["state_dict", name_best]),
        )
        try:
            # Only the stored FID is read, so there is no need to move
            # anything onto the GPU.
            state_dict_loaded = torch.load(path, map_location="cpu")
            fid = state_dict_loaded["best_FID"]
        except Exception:
            # Best effort: a missing or unreadable checkpoint is skipped
            # (but Ctrl-C / SystemExit are no longer swallowed).
            print("Checkpoint with name ", name_best, " not in folder.")
            continue
        print("For name best ", name_best, " we have an FID: ", fid)
        if fid < best_fid:
            best_fid = fid
            best_name_final = name_best
    return best_name_final


def _load_biggan_generator(config, device):
    """Build a BigGAN generator, override `config` from its checkpoint, load its weights."""
    best_name_final = _select_best_biggan_checkpoint(config)
    config["load_weights"] = best_name_final
    print("Final name selected is ", best_name_final)

    # Placeholder training state. load_weights is expected to fill it from the
    # checkpoint, including its "config" entry (NOTE(review): confirm in BigGAN utils).
    state_dict = {
        "itr": 0,
        "epoch": 0,
        "save_num": 0,
        "save_best_num": 0,
        "best_IS": 0,
        "best_FID": 999999,
        "config": config,
    }
    biggan_utils.load_weights(
        None,
        None,
        state_dict,
        config["weights_root"],
        config["experiment_name"],
        config["load_weights"],
        None,
        strict=False,
        load_optim=False,
        eval=True,
    )

    # Adopt the training-time configuration, except for run-time keys and a
    # user-supplied experiment name.
    for item, value in state_dict["config"].items():
        if item in _INFERENCE_KEPT_CONFIG_KEYS:
            continue
        if item == "experiment_name" and config["experiment_name"] != "":
            continue
        config[item] = value

    # Augmentations are disabled for inference.
    config["feature_augmentation"] = False
    config["hflips"] = False
    config["DA"] = False

    experiment_name = config["experiment_name"] or biggan_utils.name_from_config(
        config
    )
    print("Experiment name is %s" % experiment_name)

    generator = BigGANModel.Generator(**config).to(device)

    print("Loading weights...")
    # Weights go into the plain G slot, or into the EMA slot when use_ema is set.
    biggan_utils.load_weights(
        generator if not (config["use_ema"]) else None,
        None,
        state_dict,
        config["weights_root"],
        experiment_name,
        config["load_weights"],
        generator if config["ema"] and config["use_ema"] else None,
        strict=False,
        load_optim=False,
    )

    if config["G_eval_mode"]:
        print("Putting G in eval mode..")
        generator.eval()
    else:
        print("G is in %s mode..." % ("training" if generator.training else "eval"))
    return generator


def _load_stylegan2_generator(config, device):
    """Load the EMA StyleGAN2 generator from the experiment's best snapshot pickle."""
    network_pkl = os.path.join(
        config["base_root"], config["experiment_name"], "best-network-snapshot.pkl"
    )
    print('Loading networks from "%s"...' % network_pkl)
    # NOTE: this unpickles the snapshot; only load trusted files.
    with dnnlib.util.open_url(network_pkl) as f:
        return legacy.load_network_pkl(f)["G_ema"].to(device)


def load_model_inference(config, device="cuda"):
    """Load the generator network for inference and override the configuration.

    Arguments
    ---------
    config: dict
        Dictionary with configuration parameters. For the BigGAN backbone it is
        modified in place with the configuration stored in the checkpoint
        (except for the run-time keys in ``_INFERENCE_KEPT_CONFIG_KEYS``).
    device: str, optional
        Device name.

    Returns
    -------
    generator: torch.nn.module
        Generator network.
    config: dict
        The same configuration dictionary, overwritten from the model checkpoint
        if it exists (BigGAN only).

    Raises
    ------
    ValueError
        If ``config["model_backbone"]`` is neither "biggan" nor "stylegan2".
    """
    backbone = config["model_backbone"]
    if backbone == "biggan":
        generator = _load_biggan_generator(config, device)
    elif backbone == "stylegan2":
        generator = _load_stylegan2_generator(config, device)
    else:
        raise ValueError(
            "Unknown model_backbone %r; expected 'biggan' or 'stylegan2'." % (backbone,)
        )
    return generator, config
|
|
|
|
|
def add_backbone_parser(parser):
    """Register the ``--model_backbone`` option on an argparse parser.

    The first entry of the allowed choices ("biggan") is the default.
    The same parser object is returned so calls can be chained.
    """
    supported_backbones = ["biggan", "stylegan2"]
    parser.add_argument(
        "--model_backbone",
        type=str,
        default=supported_backbones[0],
        choices=supported_backbones,
        help="Backbone model? (default: %(default)s)",
    )
    return parser
|
|