# ic_gan/BigGAN_PyTorch/trainer.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
import os
import functools
import math
from tqdm import tqdm, trange
import argparse
import time
import subprocess
import re
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
# Import my stuff
import data_utils.inception_utils as inception_utils
import utils
import train_fns
from sync_batchnorm import patch_replication_callback
from data_utils import utils as data_utils
def run(config, ddp_setup="slurm", master_node=""):
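    """Entry point: fill in derived config values, prepare output folders, and
    launch either a SLURM/spawned DDP run or a single-process run."""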
config["n_classes"] = 1000 # utils.nclass_dict[self.config['dataset']]
config["G_activation"] = utils.activation_dict[config["G_nl"]]
config["D_activation"] = utils.activation_dict[config["D_nl"]]
config = utils.update_config_roots(config)
# Prepare root folders if necessary
utils.prepare_root(config)
if config["ddp_train"]:
if ddp_setup == "slurm":
n_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES"))
n_gpus_per_node = int(os.environ.get("SLURM_TASKS_PER_NODE").split("(")[0])
world_size = n_gpus_per_node * n_nodes
            print(
                "Master node is %s. World size is %d with %d GPUs per node."
                % (master_node, world_size, n_gpus_per_node)
            )
            port = 40000
            dist_url = "tcp://%s:%d" % (master_node, port)
            print("Dist url ", dist_url)
train(-1, world_size, config, dist_url)
else:
world_size = torch.cuda.device_count()
dist_url = "env://"
mp.spawn(
train, args=(world_size, config, dist_url), nprocs=world_size, join=True
)
else:
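        # Single-process (non-DDP) training; a world_size of -1 acts as a sentinel.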
train(0, -1, config, None)
def train(rank, world_size, config, dist_url):
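    """Build G/D (plus an optional EMA copy of G), the data pipeline, and the
    loggers, then run the GAN training loop, testing IS/FID periodically.

    rank is the global process rank (-1 defers to the SLURM env variables) and
    world_size is -1 for single-process, non-DDP training.
    """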
print("Rank of this job is ", rank)
copy_locally = False
tmp_dir = ""
if config["ddp_train"]:
if dist_url == "env://":
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
local_rank = rank
else:
rank = int(os.environ.get("SLURM_PROCID"))
local_rank = int(os.environ.get("SLURM_LOCALID"))
copy_locally = True
tmp_dir = "/scratch/slurm_tmpdir/" + str(os.environ.get("SLURM_JOB_ID"))
print("Before setting process group")
print(dist_url, rank)
dist.init_process_group(
backend="nccl", init_method=dist_url, rank=rank, world_size=world_size
)
print("After setting process group")
device = "cuda:{}".format(local_rank) # rank % 8)
print(dist_url, rank, " /Device is ", device)
else:
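        # Non-DDP runs use the default CUDA device for both fields.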
device = "cuda"
local_rank = "cuda"
    # Update the config dict as necessary.
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config dict (e.g. inferring the number of classes
    # and size of the images from the dataset, or passing in a PyTorch object
    # for the activation specified as a string).
# Seed RNG
utils.seed_rng(config["seed"] + rank)
# Setup cudnn.benchmark for free speed
torch.backends.cudnn.benchmark = True
if config["deterministic_run"]:
torch.backends.cudnn.deterministic = True
# Import the model--this line allows us to dynamically select different files.
model = __import__(config["model"])
experiment_name = (
config["experiment_name"]
if config["experiment_name"]
else utils.name_from_config(config)
)
print("Experiment name is %s" % experiment_name)
if config["ddp_train"]:
torch.cuda.set_device(device)
# Next, build the model
G = model.Generator(**{**config, "embedded_optimizers": False}).to(device)
D = model.Discriminator(**{**config, "embedded_optimizers": False}).to(device)
# If using EMA, prepare it
if config["ema"]:
print("Preparing EMA for G with decay of {}".format(config["ema_decay"]))
G_ema = model.Generator(**{**config, "skip_init": True, "no_optim": True}).to(
device
)
ema = utils.ema(G, G_ema, config["ema_decay"], config["ema_start"])
else:
G_ema, ema = None, None
print(
"Number of params in G: {} D: {}".format(
*[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]
)
)
# Setup the optimizers
if config["D_fp16"]:
print("Using fp16 adam ")
optim_type = utils.Adam16
else:
optim_type = optim.Adam
optimizer_D = optim_type(
params=D.parameters(),
lr=config["D_lr"],
betas=(config["D_B1"], config["D_B2"]),
weight_decay=0,
eps=config["adam_eps"],
)
optimizer_G = optim_type(
params=G.parameters(),
lr=config["G_lr"],
betas=(config["G_B1"], config["G_B2"]),
weight_decay=0,
eps=config["adam_eps"],
)
# Prepare state dict, which holds things like epoch # and itr #
state_dict = {
"itr": 0,
"epoch": 0,
"save_num": 0,
"save_best_num": 0,
"best_IS": 0,
"best_FID": 999999,
"es_epoch": 0,
"config": config,
}
# FP16?
if config["G_fp16"]:
print("Casting G to float16...")
G = G.half()
if config["ema"]:
G_ema = G_ema.half()
if config["D_fp16"]:
print("Casting D to fp16...")
D = D.half()
    # Wrap the models in DistributedDataParallel (DDP).
    if config["ddp_train"]:
        print("Before G DDP")
        G = DDP(
            G,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
        print("After G DDP")
        D = DDP(
            D,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
# If loading from a pre-trained model, load weights
print("Loading weights...")
if config["ddp_train"]:
dist.barrier()
map_location = device
else:
map_location = None
utils.load_weights(
G,
D,
state_dict,
config["weights_root"],
experiment_name,
config["load_weights"] if config["load_weights"] else None,
G_ema if config["ema"] else None,
map_location=map_location,
embedded_optimizers=False,
G_optim=optimizer_G,
D_optim=optimizer_D,
)
    # Wrapper class that joins G and D so a single forward pass runs both.
    GD = model.G_D(G, D, optimizer_G=optimizer_G, optimizer_D=optimizer_D)
    # DataParallel only applies to single-process (non-DDP) runs.
    if config["parallel"] and world_size == -1:
        GD = nn.DataParallel(GD)
if config["cross_replica"]:
patch_replication_callback(GD)
# Prepare loggers for stats; metrics holds test metrics,
# lmetrics holds any desired training metrics.
if rank == 0:
test_metrics_fname = "%s/%s_log.jsonl" % (config["logs_root"], experiment_name)
train_metrics_fname = "%s/%s" % (config["logs_root"], experiment_name)
print("Inception Metrics will be saved to {}".format(test_metrics_fname))
test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=False)
print("Training Metrics will be saved to {}".format(train_metrics_fname))
train_log = utils.MyLogger(
train_metrics_fname, reinitialize=False, logstyle=config["logstyle"]
)
# Write metadata
utils.write_metadata(config["logs_root"], experiment_name, config, state_dict)
else:
test_log = None
train_log = None
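    # The D loader yields batches large enough to cover all D steps and
    # gradient accumulations of a single G iteration.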
D_batch_size = (
config["batch_size"] * config["num_D_steps"] * config["num_D_accumulations"]
)
if config["longtail"]:
samples_per_class = np.load(
"imagenet_lt/imagenet_lt_samples_per_class.npy", allow_pickle=True
)
class_probabilities = np.load(
"imagenet_lt/imagenet_lt_class_prob.npy", allow_pickle=True
)
else:
samples_per_class, class_probabilities = None, None
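    # Build the HDF5-backed training dataset and its distributed-aware loader.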
train_dataset = data_utils.get_dataset_hdf5(
**{
**config,
"data_path": config["data_root"],
"batch_size": D_batch_size,
"augment": config["hflips"],
"local_rank": local_rank,
"copy_locally": copy_locally,
"tmp_dir": tmp_dir,
"ddp": config["ddp_train"],
}
)
train_loader = data_utils.get_dataloader(
**{
**config,
"dataset": train_dataset,
"batch_size": config["batch_size"],
"start_epoch": state_dict["epoch"],
"start_itr": state_dict["itr"],
"longtail_temperature": config["longtail_temperature"],
"samples_per_class": samples_per_class,
"class_probabilities": class_probabilities,
"rank": rank,
"world_size": world_size,
"shuffle": True,
"drop_last": True,
}
)
# Prepare inception metrics: FID and IS
is_moments_prefix = "I" if config["which_dataset"] == "imagenet" else "COCO"
im_filename = "%s%i_%s" % (
is_moments_prefix,
config["resolution"],
"" if not config["longtail"] else "longtail",
)
print("Using ", im_filename, "for Inception metrics.")
get_inception_metrics = inception_utils.prepare_inception_metrics(
im_filename,
samples_per_class,
config["parallel"],
config["no_fid"],
config["data_root"],
device=device,
)
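    # Samplers for the latent vectors (z_) and class labels (y_) used to
    # condition G during training and evaluation.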
G_batch_size = config["G_batch_size"]
z_, y_ = data_utils.prepare_z_y(
G_batch_size,
G.module.dim_z if config["ddp_train"] else G.dim_z,
config["n_classes"],
device=device,
fp16=config["G_fp16"],
longtail_gen=config["longtail_gen"],
custom_distrib=config["custom_distrib_gen"],
longtail_temperature=config["longtail_temperature"],
class_probabilities=class_probabilities,
)
# Balance instance sampling for ImageNet-LT
weights_sampling = None
if (
config["longtail"]
and config["use_balanced_sampler"]
and config["instance_cond"]
):
if config["which_knn_balance"] == "center_balance":
print(
"Balancing the instance features." "Using custom temperature distrib?",
config["custom_distrib_gen"],
" with temperature",
config["longtail_temperature"],
)
weights_sampling = data_utils.make_weights_for_balanced_classes(
samples_per_class,
train_loader.dataset.labels,
1000,
config["custom_distrib_gen"],
config["longtail_temperature"],
class_probabilities=class_probabilities,
)
# Balancing the NN classes (p(y))
elif config["which_knn_balance"] == "nnclass_balance":
print(
"Balancing the class distribution (classes drawn from the neighbors)."
" Using custom temperature distrib?",
config["custom_distrib_gen"],
" with temperature",
config["longtail_temperature"],
)
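            # Tempered softmax over the empirical class probabilities.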
weights_sampling = torch.exp(
class_probabilities / config["longtail_temperature"]
) / torch.sum(
torch.exp(class_probabilities / config["longtail_temperature"])
)
# Configure conditioning sampling function to train G
sample_conditioning = functools.partial(
data_utils.sample_conditioning_values,
z_=z_,
y_=y_,
dataset=train_dataset,
batch_size=G_batch_size,
weights_sampling=weights_sampling,
ddp=config["ddp_train"],
constant_conditioning=config["constant_conditioning"],
class_cond=config["class_cond"],
instance_cond=config["instance_cond"],
nn_sampling_strategy=config["which_knn_balance"],
)
print("G batch size ", G_batch_size)
    # Loaders are ready; build the per-iteration training closure. Named
    # train_step to avoid shadowing this train() function.
    train_step = train_fns.GAN_training_function(
G,
D,
GD,
ema,
state_dict,
config,
sample_conditioning,
embedded_optimizers=False,
device=device,
batch_size=G_batch_size,
)
# Prepare Sample function for use with inception metrics
sample = functools.partial(
utils.sample,
G=(G_ema if config["ema"] and config["use_ema"] else G),
sample_conditioning_func=sample_conditioning,
config=config,
class_cond=config["class_cond"],
instance_cond=config["instance_cond"],
)
print("Beginning training at epoch %d..." % state_dict["epoch"])
# Train for specified number of epochs, although we mostly track G iterations.
best_FID_run = state_dict["best_FID"]
FID = state_dict["best_FID"]
for epoch in range(state_dict["epoch"], config["num_epochs"]):
# Set epoch for distributed loader
if config["ddp_train"]:
train_loader.sampler.set_epoch(epoch)
# Initialize seeds at every epoch (useful for conditioning and
# noise sampling, as well as data order in the sampler)
if config["deterministic_run"]:
utils.seed_rng(config["seed"] + rank + state_dict["epoch"])
# Which progressbar to use? TQDM or my own?
if config["pbar"] == "mine":
pbar = utils.progress(
train_loader,
displaytype="s1k" if config["use_multiepoch_sampler"] else "eta",
)
else:
pbar = tqdm(train_loader)
s = time.time()
print("Before iteration, dataloader length", len(train_loader))
for i, batch in enumerate(pbar):
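            # Unpack the batch according to the conditioning mode: instance
            # features, class labels, both, or neither (unconditional).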
in_label, in_feat = None, None
if config["instance_cond"] and config["class_cond"]:
x, in_label, in_feat, _ = batch
elif config["instance_cond"]:
x, in_feat, _ = batch
elif config["class_cond"]:
x, in_label = batch
if config["constant_conditioning"]:
in_label = torch.zeros_like(in_label)
else:
x = batch
x = x.to(device, non_blocking=True)
if in_label is not None:
in_label = in_label.to(device, non_blocking=True)
if in_feat is not None:
in_feat = in_feat.float().to(device, non_blocking=True)
# Increment the iteration counter
state_dict["itr"] += 1
# Make sure G and D are in training mode, just in case they got set to eval
# For D, which typically doesn't have BN, this shouldn't matter much.
G.train()
D.train()
if config["ema"]:
G_ema.train()
            metrics = train_step(x, in_label, in_feat)
if rank == 0:
train_log.log(itr=int(state_dict["itr"]), **metrics)
# If using my progbar, print metrics.
if config["pbar"] == "mine" and rank == 0:
print(
", ".join(
["itr: %d" % state_dict["itr"]]
+ ["%s : %+4.3f" % (key, metrics[key]) for key in metrics]
),
end=" ",
)
            print("Iteration time ", time.time() - s)
            s = time.time()
# Increment epoch counter at end of epoch
state_dict["epoch"] += 1
if not (state_dict["epoch"] % config["test_every"]):
if config["G_eval_mode"]:
print("Switching G to eval mode...")
G.eval()
D.eval()
# Compute IS and FID using training dataset as reference
test_time = time.time()
IS, FID = train_fns.test(
G,
D,
G_ema,
z_,
y_,
state_dict,
config,
sample,
get_inception_metrics,
experiment_name,
test_log,
loader=None,
embedded_optimizers=False,
G_optim=optimizer_G,
D_optim=optimizer_D,
rank=rank,
)
print("Testing took ", time.time() - test_time)
            # Consider the run diverged if IS drops below half of its best value.
            if 2 * IS < state_dict["best_IS"] and config["stop_when_diverge"]:
                print("Experiment diverged!")
                break
            else:
                print("IS is ", IS, " and half of best is ", state_dict["best_IS"] / 2)
if not (state_dict["epoch"] % config["save_every"]) and rank == 0:
train_fns.save_weights(
G,
D,
G_ema,
state_dict,
config,
experiment_name,
embedded_optimizers=False,
G_optim=optimizer_G,
D_optim=optimizer_D,
)
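        # Early-stopping bookkeeping: reset the patience counter whenever the
        # run's best FID improves; otherwise count epochs without improvement.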
if rank == 0:
if FID < best_FID_run:
best_FID_run = FID
state_dict["es_epoch"] = 0
else:
state_dict["es_epoch"] += 1
if state_dict["es_epoch"] >= config["es_patience"]:
print("reached Early stopping!")
return FID
return FID