""" Utilities file |
|
This file contains utility functions for bookkeeping, logging, and data loading. |
|
Methods that directly affect training should go in layers, the model,
or train_fns.py.
|
""" |
|
|
|
from __future__ import print_function |
|
import sys |
|
import os |
|
import numpy as np |
|
import time |
|
import datetime |
|
import json |
|
import pickle |
|
from argparse import ArgumentParser |
|
import random |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision |
|
import torchvision.transforms as transforms |
|
|
|
|
|
def prepare_parser(): |
|
usage = "Parser for all scripts." |
|
parser = ArgumentParser(description=usage) |
|
|
|
parser.add_argument( |
|
"--json_config", |
|
type=str, |
|
default="", |
|
help="Json config from where to load the configuration parameters.", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--resolution", |
|
type=int, |
|
default=64, |
|
help="Resolution to train with " "(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--augment", |
|
action="store_true", |
|
default=False, |
|
help="Augment with random crops and flips (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_workers", |
|
type=int, |
|
default=8, |
|
help="Number of dataloader workers; consider using less for HDF5 " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--no_pin_memory", |
|
action="store_false", |
|
dest="pin_memory", |
|
default=True, |
|
help="Pin data into memory through dataloader? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--shuffle", |
|
action="store_true", |
|
default=False, |
|
help="Shuffle the data (strongly recommended)? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--load_in_mem", |
|
action="store_true", |
|
default=False, |
|
help="Load all data into memory? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--use_multiepoch_sampler", |
|
action="store_true", |
|
default=False, |
|
help="Use the multi-epoch sampler for dataloader? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--use_checkpointable_sampler", |
|
action="store_true", |
|
default=False, |
|
help="Use the checkpointable sampler for dataloader? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--use_balanced_sampler", |
|
action="store_true", |
|
default=False, |
|
help="Use the class balanced sampler for dataloader? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--longtail_temperature", |
|
type=int, |
|
default=1, |
|
help="Temperature to relax longtail_distribution", |
|
) |
|
|
|
parser.add_argument( |
|
"--longtail", |
|
action="store_true", |
|
default=False, |
|
help="Use long-tail version of the dataset", |
|
) |
|
parser.add_argument( |
|
"--longtail_gen", |
|
action="store_true", |
|
default=False, |
|
help="Use long-tail version of class conditioning sampling for generator.", |
|
) |
|
parser.add_argument( |
|
"--custom_distrib_gen", |
|
action="store_true", |
|
default=False, |
|
help="Use custom distribution for sampling class conditionings in generator.", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--DiffAugment", type=str, default="", help="DiffAugment policy" |
|
) |
|
parser.add_argument( |
|
"--DA", |
|
action="store_true", |
|
default=False, |
|
help="Diff Augment for GANs (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--hflips", |
|
action="store_true", |
|
default=False, |
|
help="Use horizontal flips in data augmentation." "(default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--instance_cond", |
|
action="store_true", |
|
default=False, |
|
help="Use instance features as conditioning", |
|
) |
|
parser.add_argument( |
|
"--feature_augmentation", |
|
action="store_true", |
|
default=False, |
|
help="use hflips in instance conditionings (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--which_knn_balance", |
|
type=str, |
|
default="instance_balance", |
|
choices=["instance_balance", "nnclass_balance"], |
|
help="Class balancing either done at the instance level or at the class level.", |
|
) |
|
parser.add_argument( |
|
"--G_shared_feat", |
|
action="store_true", |
|
default=False, |
|
help="Use fully connected layer for conditioning instance features in G? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--shared_dim_feat", |
|
type=int, |
|
default=2048, |
|
help="G" |
|
"s fully connected layer output dimensionality for instance features" |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--k_nn", |
|
type=int, |
|
default=50, |
|
help="Number of neigbors for each instance" "(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--feature_extractor", |
|
type=str, |
|
default="classification", |
|
choices=["classification", "selfsupervised"], |
|
help="Choice of feature extractor", |
|
) |
|
parser.add_argument( |
|
"--backbone_feature_extractor", |
|
type=str, |
|
default="resnet50", |
|
choices=["resnet50"], |
|
help="Choice of feature extractor backbone", |
|
) |
|
|
|
parser.add_argument( |
|
"--eval_instance_set", |
|
type=str, |
|
default="train", |
|
help="(Eval) Dataset split from which to draw conditioning instances (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--kmeans_subsampled", |
|
type=int, |
|
default=-1, |
|
help="Number of kmeans centers if using subsampled training instances (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--n_subsampled_data", |
|
type=float, |
|
default=-1, |
|
help="Percent of instances used at test time", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--filter_hd", |
|
type=int, |
|
default=-1, |
|
help="Hamming distance to filter val test in COCO_Stuff (by default no filtering) (default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--model", |
|
type=str, |
|
default="BigGAN", |
|
help="Name of the model module (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_param", |
|
type=str, |
|
default="SN", |
|
help="Parameterization style to use for G, spectral norm (SN) or SVD (SVD)" |
|
" or None (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_param", |
|
type=str, |
|
default="SN", |
|
help="Parameterization style to use for D, spectral norm (SN) or SVD (SVD)" |
|
" or None (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_ch", |
|
type=int, |
|
default=64, |
|
help="Channel multiplier for G (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_ch", |
|
type=int, |
|
default=64, |
|
help="Channel multiplier for D (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_depth", |
|
type=int, |
|
default=1, |
|
help="Number of resblocks per stage in G? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_depth", |
|
type=int, |
|
default=1, |
|
help="Number of resblocks per stage in D? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_thin", |
|
action="store_false", |
|
dest="D_wide", |
|
default=True, |
|
help="Use the SN-GAN channel pattern for D? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_shared", |
|
action="store_true", |
|
default=True, |
|
help="Use shared embeddings in G? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--shared_dim", |
|
type=int, |
|
default=0, |
|
help="G" |
|
"s shared embedding dimensionality; if 0, will be equal to dim_z. " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--dim_z", type=int, default=120, help="Noise dimensionality: %(default)s)" |
|
) |
|
parser.add_argument( |
|
"--z_var", type=float, default=1.0, help="Noise variance: %(default)s)" |
|
) |
|
parser.add_argument( |
|
"--hier", |
|
action="store_true", |
|
default=False, |
|
help="Use hierarchical z in G? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--syncbn", |
|
action="store_true", |
|
default=False, |
|
help="Sync batch norm? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--cross_replica", |
|
action="store_true", |
|
default=False, |
|
help="Cross_replica batchnorm in G?(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--mybn", |
|
action="store_true", |
|
default=False, |
|
help="Use my batchnorm (which supports standing stats?) %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_nl", |
|
type=str, |
|
default="relu", |
|
help="Activation function for G (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_nl", |
|
type=str, |
|
default="relu", |
|
help="Activation function for D (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_attn", |
|
type=str, |
|
default="64", |
|
help="What resolutions to use attention on for G (underscore separated) " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_attn", |
|
type=str, |
|
default="64", |
|
help="What resolutions to use attention on for D (underscore separated) " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--norm_style", |
|
type=str, |
|
default="bn", |
|
help="Normalizer style for G, one of bn [batchnorm], in [instancenorm], " |
|
"ln [layernorm], gn [groupnorm] (default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--seed", |
|
type=int, |
|
default=0, |
|
help="Random seed to use; affects both initialization and " |
|
" dataloading. (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_init", |
|
type=str, |
|
default="ortho", |
|
help="Init style to use for G (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_init", |
|
type=str, |
|
default="ortho", |
|
help="Init style to use for D(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--skip_init", |
|
action="store_true", |
|
default=False, |
|
help="Skip initialization, ideal for testing when ortho init was used " |
|
"(default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--G_lr", |
|
type=float, |
|
default=5e-5, |
|
help="Learning rate to use for Generator (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_lr", |
|
type=float, |
|
default=2e-4, |
|
help="Learning rate to use for Discriminator (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_B1", |
|
type=float, |
|
default=0.0, |
|
help="Beta1 to use for Generator (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_B1", |
|
type=float, |
|
default=0.0, |
|
help="Beta1 to use for Discriminator (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_B2", |
|
type=float, |
|
default=0.999, |
|
help="Beta2 to use for Generator (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_B2", |
|
type=float, |
|
default=0.999, |
|
help="Beta2 to use for Discriminator (default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--batch_size", |
|
type=int, |
|
default=64, |
|
help="Default overall batchsize (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_batch_size", |
|
type=int, |
|
default=0, |
|
help="Batch size to use for G; if 0, same as D (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_G_accumulations", |
|
type=int, |
|
default=1, |
|
help="Number of passes to accumulate G" |
|
"s gradients over " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_D_steps", |
|
type=int, |
|
default=2, |
|
help="Number of D steps per G step (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_D_accumulations", |
|
type=int, |
|
default=1, |
|
help="Number of passes to accumulate D" |
|
"s gradients over " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--split_D", |
|
action="store_true", |
|
default=False, |
|
help="Run D twice rather than concatenating inputs? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_epochs", |
|
type=int, |
|
default=100, |
|
help="Number of epochs to train for (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--parallel", |
|
action="store_true", |
|
default=False, |
|
help="Train with multiple GPUs (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_fp16", |
|
action="store_true", |
|
default=False, |
|
help="Train with half-precision in G? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_fp16", |
|
action="store_true", |
|
default=False, |
|
help="Train with half-precision in D? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_mixed_precision", |
|
action="store_true", |
|
default=False, |
|
help="Train with half-precision activations but fp32 params in D? " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--G_mixed_precision", |
|
action="store_true", |
|
default=False, |
|
help="Train with half-precision activations but fp32 params in G? " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--accumulate_stats", |
|
action="store_true", |
|
default=False, |
|
help='Accumulate "standing" batchnorm stats? (default: %(default)s)', |
|
) |
|
parser.add_argument( |
|
"--num_standing_accumulations", |
|
type=int, |
|
default=16, |
|
help="Number of forward passes to use in accumulating standing stats? " |
|
"(default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--slurm_logdir", |
|
help="Where to save the logs from SLURM", |
|
required=False, |
|
default="biggan-training-runs", |
|
metavar="DIR", |
|
) |
|
|
|
parser.add_argument( |
|
"--G_eval_mode", |
|
action="store_true", |
|
default=False, |
|
help="Run G in eval mode (running/standing stats?) at sample/test time? " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--save_every", |
|
type=int, |
|
default=2000, |
|
help="Save every X iterations (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_save_copies", |
|
type=int, |
|
default=2, |
|
help="How many copies to save (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_best_copies", |
|
type=int, |
|
default=2, |
|
help="How many previous best checkpoints to save (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--which_best", |
|
type=str, |
|
default="IS", |
|
        help='Which metric to use to determine when to save new "best" '
        "checkpoints, one of IS or FID (default: %(default)s)",
|
) |
|
parser.add_argument( |
|
"--no_fid", |
|
action="store_true", |
|
default=False, |
|
help="Calculate IS only, not FID? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--test_every", |
|
type=int, |
|
default=5000, |
|
help="Test every X iterations (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_inception_images", |
|
type=int, |
|
default=50000, |
|
help="Number of samples to compute inception metrics with " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--hashname", |
|
action="store_true", |
|
default=False, |
|
help="Use a hash of the experiment name instead of the full config " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--base_root", |
|
type=str, |
|
default="", |
|
help="Default location to store all weights, samples, data, and logs " |
|
" (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--data_root", |
|
type=str, |
|
default="data", |
|
help="Default location where data is stored (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--weights_root", |
|
type=str, |
|
default="weights", |
|
help="Default location to store weights (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--logs_root", |
|
type=str, |
|
default="logs", |
|
help="Default location to store logs (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--samples_root", |
|
type=str, |
|
default="samples", |
|
help="Default location to store samples (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--pbar", |
|
type=str, |
|
default="mine", |
|
help='Type of progressbar to use; one of "mine" or "tqdm" ' |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--name_suffix", |
|
type=str, |
|
default="", |
|
help="Suffix for experiment name for loading weights for sampling " |
|
'(consider "best0") (default: %(default)s)', |
|
) |
|
parser.add_argument( |
|
"--experiment_name", |
|
type=str, |
|
default="", |
|
help="Optionally override the automatic experiment naming with this arg. " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--config_from_name", |
|
action="store_true", |
|
default=False, |
|
help="Use a hash of the experiment name instead of the full config " |
|
"(default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--ema", |
|
action="store_true", |
|
default=False, |
|
help="Keep an ema of G" "s weights? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--ema_decay", |
|
type=float, |
|
default=0.9999, |
|
help="EMA decay rate (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--use_ema", |
|
action="store_true", |
|
default=False, |
|
help="Use the EMA parameters of G for evaluation? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--ema_start", |
|
type=int, |
|
default=20000, |
|
help="When to start updating the EMA weights (default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--adam_eps", |
|
type=float, |
|
default=1e-6, |
|
help="epsilon value to use for Adam (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--BN_eps", |
|
type=float, |
|
default=1e-5, |
|
help="epsilon value to use for BatchNorm (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--SN_eps", |
|
type=float, |
|
default=1e-6, |
|
help="epsilon value to use for Spectral Norm(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_G_SVs", |
|
type=int, |
|
default=1, |
|
help="Number of SVs to track in G (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_D_SVs", |
|
type=int, |
|
default=1, |
|
help="Number of SVs to track in D (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_G_SV_itrs", |
|
type=int, |
|
default=1, |
|
help="Number of SV itrs in G (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--num_D_SV_itrs", |
|
type=int, |
|
default=1, |
|
help="Number of SV itrs in D (default: %(default)s)", |
|
) |
|
|
|
parser.add_argument( |
|
"--class_cond", |
|
action="store_true", |
|
default=False, |
|
help="Use classes as conditioning", |
|
) |
|
parser.add_argument( |
|
"--constant_conditioning", |
|
action="store_true", |
|
default=False, |
|
help="Use a a class-conditioning vector where the input label is always 0? (default: %(default)s)", |
|
) |
|
|
|
parser.add_argument( |
|
"--which_dataset", |
|
type=str, |
|
default="imagenet", |
|
|
|
help="Dataset choice.", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--G_ortho", |
|
type=float, |
|
default=0.0, |
|
help="Modified ortho reg coefficient in G(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--D_ortho", |
|
type=float, |
|
default=0.0, |
|
help="Modified ortho reg coefficient in D (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--toggle_grads", |
|
action="store_true", |
|
default=True, |
|
help="Toggle D and G" |
|
's "requires_grad" settings when not training them? ' |
|
" (default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--partition", |
|
help="Partition name for SLURM", |
|
required=False, |
|
default="learnlab", |
|
) |
|
parser.add_argument( |
|
"--which_train_fn", |
|
type=str, |
|
default="GAN", |
|
help="How2trainyourbois (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--run_setup", |
|
type=str, |
|
default="slurm", |
|
help="If local_debug or slurm (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--ddp_train", |
|
action="store_true", |
|
default=False, |
|
help="If use DDP for training", |
|
) |
|
parser.add_argument( |
|
"--n_nodes", |
|
type=int, |
|
default=1, |
|
help="Number of nodes for ddp (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--n_gpus_per_node", |
|
type=int, |
|
default=1, |
|
help="Number of gpus per node for ddp (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--stop_when_diverge", |
|
action="store_true", |
|
default=False, |
|
help="Stop the experiment if there is signs of divergence. " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--es_patience", type=int, default=50, help="Epochs for early stopping patience" |
|
) |
|
parser.add_argument( |
|
"--deterministic_run", |
|
action="store_true", |
|
default=False, |
|
help="Set deterministic cudnn and set the seed at each epoch" |
|
"(default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--eval_prdc", |
|
action="store_true", |
|
default=False, |
|
help="(Eval) Evaluate prdc " " (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--eval_reference_set", |
|
type=str, |
|
default="train", |
|
help="(Eval) Reference dataset to use for FID computation (default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--load_weights", |
|
type=str, |
|
default="", |
|
help="Suffix for which weights to load (e.g. best0, copy0) " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--resume", |
|
action="store_true", |
|
default=False, |
|
help="Resume training? (default: %(default)s)", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--logstyle", |
|
type=str, |
|
default="%3.3e", |
|
help="What style to use when logging training metrics?" |
|
"One of: %#.#f/ %#.#e (float/exp, text)," |
|
"pickle (python pickle)," |
|
"npz (numpy zip)," |
|
"mat (MATLAB .mat file) (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--log_G_spectra", |
|
action="store_true", |
|
default=False, |
|
help="Log the top 3 singular values in each SN layer in G? " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--log_D_spectra", |
|
action="store_true", |
|
default=False, |
|
help="Log the top 3 singular values in each SN layer in D? " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--sv_log_interval", |
|
type=int, |
|
default=10, |
|
help="Iteration interval for logging singular values " |
|
" (default: %(default)s)", |
|
) |
|
|
|
return parser |
|
|
|
|
|
|
|
def add_sample_parser(parser): |
|
parser.add_argument( |
|
"--sample_npz", |
|
action="store_true", |
|
default=False, |
|
help='Sample "sample_num_npz" images and save to npz? ' |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--sample_num_npz", |
|
type=int, |
|
default=50000, |
|
help="Number of images to sample when sampling NPZs " "(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--sample_sheets", |
|
action="store_true", |
|
default=False, |
|
help="Produce class-conditional sample sheets and stick them in " |
|
"the samples root? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--sample_interps", |
|
action="store_true", |
|
default=False, |
|
help="Produce interpolation sheets and stick them in " |
|
"the samples root? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--sample_sheet_folder_num", |
|
type=int, |
|
default=-1, |
|
help="Number to use for the folder for these sample sheets " |
|
"(default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--sample_random", |
|
action="store_true", |
|
default=False, |
|
help="Produce a single random sheet? (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--sample_trunc_curves", |
|
type=str, |
|
default="", |
|
help="Get inception metrics with a range of variances?" |
|
"To use this, specify a startpoint, step, and endpoint, e.g. " |
|
"--sample_trunc_curves 0.2_0.1_1.0 for a startpoint of 0.2, " |
|
"endpoint of 1.0, and stepsize of 1.0. Note that this is " |
|
"not exactly identical to using tf.truncated_normal, but should " |
|
"have approximately the same effect. (default: %(default)s)", |
|
) |
|
parser.add_argument( |
|
"--sample_inception_metrics", |
|
action="store_true", |
|
default=False, |
|
help="Calculate Inception metrics with sample.py? (default: %(default)s)", |
|
) |
|
return parser |
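# Usage sketch (illustrative only; the entry-point script and how the resulting
# dict is consumed downstream are assumptions, not part of this file):
#
#     parser = prepare_parser()
#     parser = add_sample_parser(parser)   # only needed for sampling scripts
#     config = vars(parser.parse_args())
#     config = update_config_roots(config)
#     prepare_root(config)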
|
|
|
|
|
activation_dict = { |
|
"inplace_relu": nn.ReLU(inplace=True), |
|
"relu": nn.ReLU(inplace=False), |
|
"ir": nn.ReLU(inplace=True), |
|
} |
|
|
|
|
|
class CenterCropLongEdge(object): |
|
"""Crops the given PIL Image on the long edge. |
|
Args: |
|
size (sequence or int): Desired output size of the crop. If size is an |
|
int instead of sequence like (h, w), a square crop (size, size) is |
|
made. |
|
""" |
|
|
|
def __call__(self, img): |
|
""" |
|
Args: |
|
img (PIL Image): Image to be cropped. |
|
Returns: |
|
PIL Image: Cropped image. |
|
""" |
|
return transforms.functional.center_crop(img, min(img.size)) |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ |
|
|
|
|
|
class RandomCropLongEdge(object): |
|
"""Crops the given PIL Image on the long edge with a random start point. |
|
Args: |
|
size (sequence or int): Desired output size of the crop. If size is an |
|
int instead of sequence like (h, w), a square crop (size, size) is |
|
made. |
|
""" |
|
|
|
def __call__(self, img): |
|
""" |
|
Args: |
|
img (PIL Image): Image to be cropped. |
|
Returns: |
|
PIL Image: Cropped image. |
|
""" |
|
size = (min(img.size), min(img.size)) |
|
|
|
i = ( |
|
0 |
|
if size[0] == img.size[0] |
|
else np.random.randint(low=0, high=img.size[0] - size[0]) |
|
) |
|
j = ( |
|
0 |
|
if size[1] == img.size[1] |
|
else np.random.randint(low=0, high=img.size[1] - size[1]) |
|
) |
|
return transforms.functional.crop(img, i, j, size[0], size[1]) |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ |
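# Sketch of how these crops are typically composed into a torchvision pipeline
# (illustrative; the normalization statistics and the pairing with --augment /
# --hflips below are assumptions, not dictated by this file):
#
#     norm_mean, norm_std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
#     train_transform = transforms.Compose([
#         RandomCropLongEdge(),
#         transforms.Resize(config["resolution"]),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize(norm_mean, norm_std),
#     ])
#     eval_transform = transforms.Compose([
#         CenterCropLongEdge(),
#         transforms.Resize(config["resolution"]),
#         transforms.ToTensor(),
#         transforms.Normalize(norm_mean, norm_std),
#     ])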
|
|
|
|
|
|
|
def seed_rng(seed): |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
np.random.seed(seed) |
|
|
|
|
|
def seed_worker(worker_id):
    # Standard PyTorch worker_init_fn recipe: derive a per-worker seed from the
    # torch seed so numpy/random state differs across dataloader workers.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
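# Example (a minimal sketch; the dataset object is a placeholder):
#
#     loader = torch.utils.data.DataLoader(
#         dataset,
#         batch_size=config["batch_size"],
#         num_workers=config["num_workers"],
#         pin_memory=config["pin_memory"],
#         worker_init_fn=seed_worker,
#     )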
|
|
|
|
|
|
|
|
|
def update_config_roots(config, change_weight_folder=True): |
|
if config["base_root"]: |
|
print("Pegging all root folders to base root %s" % config["base_root"]) |
|
for key in ["weights", "logs", "samples"]: |
|
if change_weight_folder: |
|
config["%s_root" % key] = "%s/%s" % (config["base_root"], key) |
|
else: |
|
config["%s_root" % key] = "%s" % (config["base_root"]) |
|
return config |
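# For example, with base_root="/checkpoints/run1" and change_weight_folder=True,
# weights_root / logs_root / samples_root become "/checkpoints/run1/weights",
# "/checkpoints/run1/logs" and "/checkpoints/run1/samples"; with
# change_weight_folder=False all three point directly at "/checkpoints/run1".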
|
|
|
|
|
|
|
def prepare_root(config): |
|
for key in ["weights_root", "logs_root", "samples_root"]: |
|
if not os.path.exists(config[key]): |
|
print("Making directory %s for %s..." % (config[key], key)) |
|
os.mkdir(config[key]) |
|
|
|
|
|
|
|
|
|
|
|
class ema(object): |
|
def __init__(self, source, target, decay=0.9999, start_itr=0): |
|
self.source = source |
|
self.target = target |
|
self.decay = decay |
|
|
|
self.start_itr = start_itr |
|
|
|
self.source_dict = self.source.state_dict() |
|
self.target_dict = self.target.state_dict() |
|
print("Initializing EMA parameters to be source parameters...") |
|
with torch.no_grad(): |
|
for key in self.source_dict: |
|
self.target_dict[key].data.copy_(self.source_dict[key].data) |
|
|
|
|
|
def update(self, itr=None): |
|
|
|
|
|
        if itr is not None and itr < self.start_itr:
|
decay = 0.0 |
|
else: |
|
decay = self.decay |
|
with torch.no_grad(): |
|
for key in self.source_dict: |
|
self.target_dict[key].data.copy_( |
|
self.target_dict[key].data * decay |
|
+ self.source_dict[key].data * (1 - decay) |
|
) |
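# Typical wiring (a sketch, assuming G_ema is a copy of G kept for evaluation
# and that the training loop tracks the current iteration in state_dict["itr"]):
#
#     ema_ = ema(G, G_ema, decay=config["ema_decay"], start_itr=config["ema_start"])
#     ...
#     # after each G update:
#     ema_.update(state_dict["itr"])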
|
|
|
|
|
|
|
|
|
|
|
def ortho(model, strength=1e-4, blacklist=[]): |
|
with torch.no_grad(): |
|
for param in model.parameters(): |
|
|
|
if len(param.shape) < 2 or any([param is item for item in blacklist]): |
|
continue |
|
w = param.view(param.shape[0], -1) |
|
grad = 2 * torch.mm( |
|
torch.mm(w, w.t()) * (1.0 - torch.eye(w.shape[0], device=w.device)), w |
|
) |
|
param.grad.data += strength * grad.view(param.shape) |
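# Sketch of where this regularizer is meant to be applied in a training step
# (the blacklist of shared-embedding parameters is an assumption about how G is
# structured, not enforced by this file):
#
#     loss.backward()
#     if config["G_ortho"] > 0.0:
#         ortho(G, config["G_ortho"],
#               blacklist=[param for param in G.shared.parameters()])
#     G.optim.step()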
|
|
|
|
|
|
|
|
|
|
|
def default_ortho(model, strength=1e-4, blacklist=[]): |
|
with torch.no_grad(): |
|
for param in model.parameters(): |
|
|
|
if len(param.shape) < 2 or param in blacklist: |
|
continue |
|
w = param.view(param.shape[0], -1) |
|
grad = 2 * torch.mm( |
|
torch.mm(w, w.t()) - torch.eye(w.shape[0], device=w.device), w |
|
) |
|
param.grad.data += strength * grad.view(param.shape) |
|
|
|
|
|
|
|
def toggle_grad(model, on_or_off): |
|
for param in model.parameters(): |
|
param.requires_grad = on_or_off |
|
|
|
|
|
|
|
|
|
|
|
def join_strings(base_string, strings): |
|
return base_string.join([item for item in strings if item]) |
|
|
|
|
|
|
|
def save_weights( |
|
G, |
|
D, |
|
state_dict, |
|
weights_root, |
|
experiment_name, |
|
name_suffix=None, |
|
G_ema=None, |
|
embedded_optimizers=True, |
|
G_optim=None, |
|
D_optim=None, |
|
): |
|
root = "/".join([weights_root, experiment_name]) |
|
if not os.path.exists(root): |
|
os.mkdir(root) |
|
if name_suffix: |
|
print("Saving weights to %s/%s..." % (root, name_suffix)) |
|
else: |
|
print("Saving weights to %s..." % root) |
|
torch.save( |
|
G.state_dict(), "%s/%s.pth" % (root, join_strings("_", ["G", name_suffix])) |
|
) |
|
torch.save( |
|
D.state_dict(), "%s/%s.pth" % (root, join_strings("_", ["D", name_suffix])) |
|
) |
|
torch.save( |
|
state_dict, "%s/%s.pth" % (root, join_strings("_", ["state_dict", name_suffix])) |
|
) |
|
|
|
if embedded_optimizers: |
|
torch.save( |
|
G.optim.state_dict(), |
|
"%s/%s.pth" % (root, join_strings("_", ["G_optim", name_suffix])), |
|
) |
|
torch.save( |
|
D.optim.state_dict(), |
|
"%s/%s.pth" % (root, join_strings("_", ["D_optim", name_suffix])), |
|
) |
|
else: |
|
torch.save( |
|
G_optim.state_dict(), |
|
"%s/%s.pth" % (root, join_strings("_", ["G_optim", name_suffix])), |
|
) |
|
torch.save( |
|
D_optim.state_dict(), |
|
"%s/%s.pth" % (root, join_strings("_", ["D_optim", name_suffix])), |
|
) |
|
if G_ema is not None: |
|
torch.save( |
|
G_ema.state_dict(), |
|
"%s/%s.pth" % (root, join_strings("_", ["G_ema", name_suffix])), |
|
) |
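# The resulting checkpoint layout under weights_root/experiment_name/ is:
#     G[_suffix].pth, D[_suffix].pth, G_ema[_suffix].pth (if G_ema is given),
#     G_optim[_suffix].pth, D_optim[_suffix].pth, state_dict[_suffix].pth
# where the optional suffix comes from name_suffix (e.g. "best0" or "copy0").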
|
|
|
|
|
|
|
def load_weights( |
|
G, |
|
D, |
|
state_dict, |
|
weights_root, |
|
experiment_name, |
|
name_suffix=None, |
|
G_ema=None, |
|
strict=True, |
|
load_optim=True, |
|
eval=False, |
|
map_location=None, |
|
embedded_optimizers=True, |
|
G_optim=None, |
|
D_optim=None, |
|
): |
|
root = "/".join([weights_root, experiment_name]) |
|
if not os.path.exists(root): |
|
print("Not loading data, experiment folder does not exist yet!") |
|
print(root) |
|
if eval: |
|
raise ValueError("Make sure foder exists") |
|
return |
|
|
|
if name_suffix: |
|
print("Loading %s weights from %s..." % (name_suffix, root)) |
|
else: |
|
print("Loading weights from %s..." % root) |
|
if G is not None: |
|
G.load_state_dict( |
|
torch.load( |
|
"%s/%s.pth" % (root, join_strings("_", ["G", name_suffix])), |
|
map_location=map_location, |
|
), |
|
strict=strict, |
|
) |
|
if load_optim: |
|
if embedded_optimizers: |
|
G.optim.load_state_dict( |
|
torch.load( |
|
"%s/%s.pth" |
|
% (root, join_strings("_", ["G_optim", name_suffix])), |
|
map_location=map_location, |
|
) |
|
) |
|
else: |
|
G_optim.load_state_dict( |
|
torch.load( |
|
"%s/%s.pth" |
|
% (root, join_strings("_", ["G_optim", name_suffix])), |
|
map_location=map_location, |
|
) |
|
) |
|
if D is not None: |
|
D.load_state_dict( |
|
torch.load( |
|
"%s/%s.pth" % (root, join_strings("_", ["D", name_suffix])), |
|
map_location=map_location, |
|
), |
|
strict=strict, |
|
) |
|
if load_optim: |
|
if embedded_optimizers: |
|
D.optim.load_state_dict( |
|
torch.load( |
|
"%s/%s.pth" |
|
% (root, join_strings("_", ["D_optim", name_suffix])), |
|
map_location=map_location, |
|
) |
|
) |
|
else: |
|
D_optim.load_state_dict( |
|
torch.load( |
|
"%s/%s.pth" |
|
% (root, join_strings("_", ["D_optim", name_suffix])), |
|
map_location=map_location, |
|
) |
|
) |
|
|
|
for item in state_dict: |
|
try: |
|
state_dict[item] = torch.load( |
|
"%s/%s.pth" % (root, join_strings("_", ["state_dict", name_suffix])), |
|
map_location=map_location, |
|
)[item] |
|
        except Exception:
            print("No saved value for %s to load" % item)
|
if G_ema is not None: |
|
G_ema.load_state_dict( |
|
torch.load( |
|
"%s/%s.pth" % (root, join_strings("_", ["G_ema", name_suffix])), |
|
map_location=map_location, |
|
), |
|
strict=strict, |
|
) |
|
|
|
|
|
""" MetricsLogger originally stolen from VoxNet source code. |
|
Used for logging inception metrics""" |
|
|
|
|
|
class MetricsLogger(object): |
|
def __init__(self, fname, reinitialize=False): |
|
self.fname = fname |
|
self.reinitialize = reinitialize |
|
if os.path.exists(self.fname): |
|
if self.reinitialize: |
|
print("{} exists, deleting...".format(self.fname)) |
|
os.remove(self.fname) |
|
|
|
def log(self, record=None, **kwargs): |
|
""" |
|
Assumption: no newlines in the input. |
|
""" |
|
if record is None: |
|
record = {} |
|
record.update(kwargs) |
|
record["_stamp"] = time.time() |
|
with open(self.fname, "a") as f: |
|
f.write(json.dumps(record, ensure_ascii=True) + "\n") |
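# Each call appends one JSON line to fname, e.g.
#     logger = MetricsLogger("inception_metrics.jsonl")
#     logger.log(itr=1000, IS_mean=12.3, FID=45.6)
# writes {"itr": 1000, "IS_mean": 12.3, "FID": 45.6, "_stamp": 1234567890.1}
# (the metric names and filename above are illustrative).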
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MyLogger(object): |
|
def __init__(self, fname, reinitialize=False, logstyle="%3.3f"): |
|
self.root = fname |
|
if not os.path.exists(self.root): |
|
os.mkdir(self.root) |
|
self.reinitialize = reinitialize |
|
self.metrics = [] |
|
self.logstyle = logstyle |
|
|
|
|
|
def reinit(self, item): |
|
if os.path.exists("%s/%s.log" % (self.root, item)): |
|
if self.reinitialize: |
|
|
|
if "sv" in item: |
|
if not any("sv" in item for item in self.metrics): |
|
print("Deleting singular value logs...") |
|
else: |
|
print( |
|
"{} exists, deleting...".format("%s_%s.log" % (self.root, item)) |
|
) |
|
os.remove("%s/%s.log" % (self.root, item)) |
|
|
|
|
|
def log(self, itr, **kwargs): |
|
for arg in kwargs: |
|
if arg not in self.metrics: |
|
if self.reinitialize: |
|
self.reinit(arg) |
|
self.metrics += [arg] |
|
if self.logstyle == "pickle": |
|
print("Pickle not currently supported...") |
|
|
|
|
|
elif self.logstyle == "mat": |
|
print(".mat logstyle not currently supported...") |
|
else: |
|
with open("%s/%s.log" % (self.root, arg), "a") as f: |
|
f.write("%d: %s\n" % (itr, self.logstyle % kwargs[arg])) |
|
|
|
|
|
|
|
def write_metadata(logs_root, experiment_name, config, state_dict): |
|
with open(("%s/%s/metalog.txt" % (logs_root, experiment_name)), "w") as writefile: |
|
writefile.write("datetime: %s\n" % str(datetime.datetime.now())) |
|
writefile.write("config: %s\n" % str(config)) |
|
writefile.write("state: %s\n" % str(state_dict)) |
|
|
|
|
|
""" |
|
Very basic progress indicator to wrap an iterable in. |
|
|
|
Author: Jan Schlüter |
|
Andy's adds: time elapsed in addition to ETA, makes it possible to add |
|
estimated time to 1k iters instead of estimated time to completion. |
|
""" |
|
|
|
|
|
def progress(items, desc="", total=None, min_delay=0.1, displaytype="s1k"): |
|
""" |
|
Returns a generator over `items`, printing the number and percentage of |
|
items processed and the estimated remaining processing time before yielding |
|
the next item. `total` gives the total number of items (required if `items` |
|
has no length), and `min_delay` gives the minimum time in seconds between |
|
subsequent prints. `desc` gives an optional prefix text (end with a space). |
|
""" |
|
total = total or len(items) |
|
t_start = time.time() |
|
t_last = 0 |
|
for n, item in enumerate(items): |
|
t_now = time.time() |
|
if t_now - t_last > min_delay: |
|
print( |
|
"\r%s%d/%d (%6.2f%%)" % (desc, n + 1, total, n / float(total) * 100), |
|
end=" ", |
|
) |
|
if n > 0: |
|
|
|
if displaytype == "s1k": |
|
next_1000 = n + (1000 - n % 1000) |
|
t_done = t_now - t_start |
|
t_1k = t_done / n * next_1000 |
|
outlist = list(divmod(t_done, 60)) + list(divmod(t_1k - t_done, 60)) |
|
print("(TE/ET1k: %d:%02d / %d:%02d)" % tuple(outlist), end=" ") |
|
else: |
|
t_done = t_now - t_start |
|
t_total = t_done / n * total |
|
outlist = list(divmod(t_done, 60)) + list( |
|
divmod(t_total - t_done, 60) |
|
) |
|
print("(TE/ETA: %d:%02d / %d:%02d)" % tuple(outlist), end=" ") |
|
|
|
sys.stdout.flush() |
|
t_last = t_now |
|
yield item |
|
t_total = time.time() - t_start |
|
print( |
|
"\r%s%d/%d (100.00%%) (took %d:%02d)" |
|
% ((desc, total, total) + divmod(t_total, 60)) |
|
) |
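# Example: wrapping a dataloader so each epoch prints running progress and an
# ETA instead of the time-to-1k-iterations estimate (variable names illustrative):
#
#     for x, y in progress(loader, desc="Epoch %d: " % epoch, displaytype="eta"):
#         train_step(x, y)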
|
|
|
|
|
|
|
def sample( |
|
G, |
|
sample_conditioning_func, |
|
config, |
|
class_cond=True, |
|
instance_cond=False, |
|
device="cuda", |
|
): |
|
conditioning = sample_conditioning_func() |
|
with torch.no_grad(): |
|
if class_cond and not instance_cond: |
|
z_, y_ = conditioning |
|
y_ = y_.long() |
|
y_ = y_.to(device, non_blocking=True) |
|
feats_ = None |
|
elif instance_cond and not class_cond: |
|
z_, feats_ = conditioning |
|
feats_ = feats_.to(device, non_blocking=True) |
|
y_ = None |
|
elif instance_cond and class_cond: |
|
z_, y_, feats_ = conditioning |
|
y_, feats_ = ( |
|
y_.to(device, non_blocking=True), |
|
feats_.to(device, non_blocking=True), |
|
) |
|
z_ = z_.to(device, non_blocking=True) |
|
|
|
if config["parallel"]: |
|
G_z = nn.parallel.data_parallel(G, (z_, y_, feats_)) |
|
else: |
|
G_z = G(z_, y_, feats_) |
|
return G_z, y_, feats_ |
|
|
|
|
|
|
|
def sample_sheet( |
|
G, |
|
classes_per_sheet, |
|
num_classes, |
|
samples_per_class, |
|
parallel, |
|
samples_root, |
|
experiment_name, |
|
folder_number, |
|
z_=None, |
|
): |
|
|
|
if not os.path.isdir("%s/%s" % (samples_root, experiment_name)): |
|
os.mkdir("%s/%s" % (samples_root, experiment_name)) |
|
if not os.path.isdir("%s/%s/%d" % (samples_root, experiment_name, folder_number)): |
|
os.mkdir("%s/%s/%d" % (samples_root, experiment_name, folder_number)) |
|
|
|
for i in range(num_classes // classes_per_sheet): |
|
ims = [] |
|
y = torch.arange( |
|
i * classes_per_sheet, (i + 1) * classes_per_sheet, device="cuda" |
|
) |
|
for j in range(samples_per_class): |
|
if ( |
|
(z_ is not None) |
|
and hasattr(z_, "sample_") |
|
and classes_per_sheet <= z_.size(0) |
|
): |
|
z_.sample_() |
|
else: |
|
z_ = torch.randn(classes_per_sheet, G.dim_z, device="cuda") |
|
with torch.no_grad(): |
|
if parallel: |
|
o = nn.parallel.data_parallel( |
|
G, (z_[:classes_per_sheet], G.shared(y)) |
|
) |
|
else: |
|
o = G(z_[:classes_per_sheet], G.shared(y)) |
|
|
|
ims += [o.data.cpu()] |
|
|
|
out_ims = ( |
|
torch.stack(ims, 1) |
|
.view(-1, ims[0].shape[1], ims[0].shape[2], ims[0].shape[3]) |
|
.data.float() |
|
.cpu() |
|
) |
|
|
|
image_filename = "%s/%s/%d/samples%d.jpg" % ( |
|
samples_root, |
|
experiment_name, |
|
folder_number, |
|
i, |
|
) |
|
torchvision.utils.save_image( |
|
out_ims, image_filename, nrow=samples_per_class, normalize=True |
|
) |
|
|
|
|
|
|
|
def interp(x0, x1, num_midpoints): |
|
lerp = torch.linspace(0, 1.0, num_midpoints + 2, device="cuda").to(x0.dtype) |
|
return (x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1)) |
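# interp expects endpoints shaped (num_rows, 1, dim) and returns
# (num_rows, num_midpoints + 2, dim): each row runs from x0 through the
# midpoints to x1. For example:
#
#     x0 = torch.randn(4, 1, 120, device="cuda")
#     x1 = torch.randn(4, 1, 120, device="cuda")
#     grid = interp(x0, x1, num_midpoints=8)   # shape (4, 10, 120)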
|
|
|
|
|
|
|
|
|
def interp_sheet( |
|
G, |
|
num_per_sheet, |
|
num_midpoints, |
|
num_classes, |
|
parallel, |
|
samples_root, |
|
experiment_name, |
|
folder_number, |
|
sheet_number=0, |
|
fix_z=False, |
|
fix_y=False, |
|
device="cuda", |
|
): |
|
|
|
if fix_z: |
|
zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device) |
|
zs = zs.repeat(1, num_midpoints + 2, 1).view(-1, G.dim_z) |
|
else: |
|
zs = interp( |
|
torch.randn(num_per_sheet, 1, G.dim_z, device=device), |
|
torch.randn(num_per_sheet, 1, G.dim_z, device=device), |
|
num_midpoints, |
|
).view(-1, G.dim_z) |
|
if fix_y: |
|
ys = sample_1hot(num_per_sheet, num_classes) |
|
ys = G.shared(ys).view(num_per_sheet, 1, -1) |
|
ys = ys.repeat(1, num_midpoints + 2, 1).view( |
|
num_per_sheet * (num_midpoints + 2), -1 |
|
) |
|
else: |
|
ys = interp( |
|
G.shared(sample_1hot(num_per_sheet, num_classes)).view( |
|
num_per_sheet, 1, -1 |
|
), |
|
G.shared(sample_1hot(num_per_sheet, num_classes)).view( |
|
num_per_sheet, 1, -1 |
|
), |
|
num_midpoints, |
|
).view(num_per_sheet * (num_midpoints + 2), -1) |
|
|
|
if G.fp16: |
|
zs = zs.half() |
|
with torch.no_grad(): |
|
if parallel: |
|
out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu() |
|
else: |
|
out_ims = G(zs, ys).data.cpu() |
|
interp_style = "" + ("Z" if not fix_z else "") + ("Y" if not fix_y else "") |
|
image_filename = "%s/%s/%d/interp%s%d.jpg" % ( |
|
samples_root, |
|
experiment_name, |
|
folder_number, |
|
interp_style, |
|
sheet_number, |
|
) |
|
torchvision.utils.save_image( |
|
out_ims, image_filename, nrow=num_midpoints + 2, normalize=True |
|
) |
|
|
|
|
|
|
|
|
|
def print_grad_norms(net): |
|
gradsums = [ |
|
[ |
|
float(torch.norm(param.grad).item()), |
|
float(torch.norm(param).item()), |
|
param.shape, |
|
] |
|
for param in net.parameters() |
|
] |
|
order = np.argsort([item[0] for item in gradsums]) |
|
print( |
|
[ |
|
"%3.3e,%3.3e, %s" |
|
% ( |
|
gradsums[item_index][0], |
|
gradsums[item_index][1], |
|
str(gradsums[item_index][2]), |
|
) |
|
for item_index in order |
|
] |
|
) |
|
|
|
|
|
|
|
|
|
def get_SVs(net, prefix): |
|
d = net.state_dict() |
|
return { |
|
("%s_%s" % (prefix, key)).replace(".", "_"): float(d[key].item()) |
|
for key in d |
|
if "sv" in key |
|
} |
|
|
|
|
|
|
|
def name_from_config(config): |
|
name = "_".join( |
|
[ |
|
item |
|
for item in [ |
|
"Big%s" % config["which_train_fn"], |
|
config["dataset"], |
|
config["model"] if config["model"] != "BigGAN" else None, |
|
"seed%d" % config["seed"], |
|
"Gch%d" % config["G_ch"], |
|
"Dch%d" % config["D_ch"], |
|
"Gd%d" % config["G_depth"] if config["G_depth"] > 1 else None, |
|
"Dd%d" % config["D_depth"] if config["D_depth"] > 1 else None, |
|
"bs%d" % config["batch_size"], |
|
"Gfp16" if config["G_fp16"] else None, |
|
"Dfp16" if config["D_fp16"] else None, |
|
"nDs%d" % config["num_D_steps"] if config["num_D_steps"] > 1 else None, |
|
"nDa%d" % config["num_D_accumulations"] |
|
if config["num_D_accumulations"] > 1 |
|
else None, |
|
"nGa%d" % config["num_G_accumulations"] |
|
if config["num_G_accumulations"] > 1 |
|
else None, |
|
"Glr%2.1e" % config["G_lr"], |
|
"Dlr%2.1e" % config["D_lr"], |
|
"GB%3.3f" % config["G_B1"] if config["G_B1"] != 0.0 else None, |
|
"GBB%3.3f" % config["G_B2"] if config["G_B2"] != 0.999 else None, |
|
"DB%3.3f" % config["D_B1"] if config["D_B1"] != 0.0 else None, |
|
"DBB%3.3f" % config["D_B2"] if config["D_B2"] != 0.999 else None, |
|
"Gnl%s" % config["G_nl"], |
|
"Dnl%s" % config["D_nl"], |
|
"Ginit%s" % config["G_init"], |
|
"Dinit%s" % config["D_init"], |
|
"G%s" % config["G_param"] if config["G_param"] != "SN" else None, |
|
"D%s" % config["D_param"] if config["D_param"] != "SN" else None, |
|
"Gattn%s" % config["G_attn"] if config["G_attn"] != "0" else None, |
|
"Dattn%s" % config["D_attn"] if config["D_attn"] != "0" else None, |
|
"Gortho%2.1e" % config["G_ortho"] if config["G_ortho"] > 0.0 else None, |
|
"Dortho%2.1e" % config["D_ortho"] if config["D_ortho"] > 0.0 else None, |
|
config["norm_style"] if config["norm_style"] != "bn" else None, |
|
"cr" if config["cross_replica"] else None, |
|
"Gshared" if config["G_shared"] else None, |
|
"hier" if config["hier"] else None, |
|
"ema" if config["ema"] else None, |
|
config["name_suffix"] if config["name_suffix"] else None, |
|
] |
|
if item is not None |
|
] |
|
    )
    return name
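# With mostly default settings this produces names along the lines of
# "BigGAN_imagenet_seed0_Gch64_Dch64_bs64_nDs2_Glr5.0e-05_Dlr2.0e-04_..."
# (illustrative; the exact fields depend on which options differ from their
# defaults in the config).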
|
|
|
|
|
|
|
def query_gpu(indices):
    # Report free memory for the requested GPU indices (GPU 0 if none are given).
    devices = ",".join(str(i) for i in indices) if indices else "0"
    os.system("nvidia-smi -i %s --query-gpu=memory.free --format=csv" % devices)
|
|
|
|
|
|
|
def count_parameters(module): |
|
print( |
|
"Number of parameters: {}".format( |
|
sum([p.data.nelement() for p in module.parameters()]) |
|
) |
|
) |
|
|
|
|
|
|
|
def sample_1hot(batch_size, num_classes, device="cuda"): |
|
return torch.randint( |
|
low=0, |
|
high=num_classes, |
|
size=(batch_size,), |
|
device=device, |
|
dtype=torch.int64, |
|
requires_grad=False, |
|
) |
|
|
|
|
|
def initiate_standing_stats(net): |
|
for module in net.modules(): |
|
if hasattr(module, "accumulate_standing"): |
|
module.reset_stats() |
|
module.accumulate_standing = True |
|
|
|
|
|
def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16): |
|
initiate_standing_stats(net) |
|
net.train() |
|
for i in range(num_accumulations): |
|
with torch.no_grad(): |
|
z.normal_() |
|
y.random_(0, nclasses) |
|
x = net(z, net.shared(y)) |
|
|
|
net.eval() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from torch.optim.optimizer import Optimizer |
|
|
|
|
|
class Adam16(Optimizer): |
|
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): |
|
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) |
|
params = list(params) |
|
super(Adam16, self).__init__(params, defaults) |
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
super(Adam16, self).load_state_dict(state_dict) |
|
for group in self.param_groups: |
|
for p in group["params"]: |
|
self.state[p]["exp_avg"] = self.state[p]["exp_avg"].float() |
|
self.state[p]["exp_avg_sq"] = self.state[p]["exp_avg_sq"].float() |
|
self.state[p]["fp32_p"] = self.state[p]["fp32_p"].float() |
|
|
|
def step(self, closure=None): |
|
"""Performs a single optimization step. |
|
Arguments: |
|
closure (callable, optional): A closure that reevaluates the model |
|
and returns the loss. |
|
""" |
|
loss = None |
|
if closure is not None: |
|
loss = closure() |
|
|
|
for group in self.param_groups: |
|
for p in group["params"]: |
|
if p.grad is None: |
|
continue |
|
|
|
grad = p.grad.data.float() |
|
state = self.state[p] |
|
|
|
|
|
if len(state) == 0: |
|
state["step"] = 0 |
|
|
|
state["exp_avg"] = grad.new().resize_as_(grad).zero_() |
|
|
|
state["exp_avg_sq"] = grad.new().resize_as_(grad).zero_() |
|
|
|
state["fp32_p"] = p.data.float() |
|
|
|
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] |
|
beta1, beta2 = group["betas"] |
|
|
|
state["step"] += 1 |
|
|
|
if group["weight_decay"] != 0: |
|
grad = grad.add(group["weight_decay"], state["fp32_p"]) |
|
|
|
|
|
                # Decay the first and second moment running averages.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
|
|
denom = exp_avg_sq.sqrt().add_(group["eps"]) |
|
|
|
bias_correction1 = 1 - beta1 ** state["step"] |
|
bias_correction2 = 1 - beta2 ** state["step"] |
|
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 |
|
|
|
state["fp32_p"].addcdiv_(-step_size, exp_avg, denom) |
|
p.data = state["fp32_p"].half() |
|
|
|
return loss |
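# Minimal usage sketch (assumes the model has already been converted to half
# precision, e.g. when G_fp16/D_fp16 is set; Adam16 keeps fp32 master copies of
# the parameters internally and writes back half-precision weights each step):
#
#     model = model.half()
#     optim = Adam16(model.parameters(), lr=config["G_lr"],
#                    betas=(config["G_B1"], config["G_B2"]), eps=config["adam_eps"])
#     loss.backward()
#     optim.step()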
|
|