# Source: ic_gan/BigGAN_PyTorch/train_fns.py
# (Repository viewer header: author ArantxaCasanova, "First model version",
# commit a00ee36, 11.7 kB.)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
# MIT License
""" train_fns.py
Functions for the main loop of training different conditional image models
"""
import torch
import utils
import losses
# Dummy training function for debugging
def dummy_training_function():
    """Return a no-op training step, useful for debugging the outer loop.

    Returns:
        A function ``train(x, y)`` that ignores both arguments and returns
        a new empty dict on every call.
    """

    def train(x, y):
        """Ignore the batch and report no metrics."""
        return dict()

    return train
def GAN_training_function(
    G,
    D,
    GD,
    ema,
    state_dict,
    config,
    sample_conditionings,
    embedded_optimizers=True,
    device="cuda",
    batch_size=0,
):
    """Build the closure that runs one GAN training iteration.

    Args:
        G: Generator. Must expose ``G.optim`` when ``embedded_optimizers`` is
            true, and ``G.shared`` when ``config["G_ortho"] > 0``.
        D: Discriminator. Must expose ``D.optim`` when
            ``embedded_optimizers`` is true.
        GD: Callable combining G and D. Called with ``train_G=False`` it must
            return ``(D_fake, D_real)``; with ``train_G=True`` it must return
            ``D_fake``. Must expose ``optimizer_D`` and ``optimizer_G`` when
            ``embedded_optimizers`` is false.
        ema: Object whose ``update(itr)`` is called after the G step when
            ``config["ema"]`` is truthy.
        state_dict: Mapping; only ``state_dict["itr"]`` is read here.
        config: Mapping; keys read here are ``toggle_grads``, ``num_D_steps``,
            ``num_D_accumulations``, ``num_G_accumulations``, ``split_D``,
            ``DiffAugment``, ``DA``, ``D_ortho``, ``G_ortho`` and ``ema``.
        sample_conditionings: Zero-argument callable returning a tuple that
            starts with the noise ``z_`` and continues with labels and/or
            features, matching whichever of ``y`` / ``features`` is passed to
            the returned ``train`` function.
        embedded_optimizers: If true, optimizers live on ``G.optim`` /
            ``D.optim``; otherwise on ``GD.optimizer_G`` / ``GD.optimizer_D``.
        device: Target passed to ``Tensor.to`` for the sampled conditionings.
        batch_size: Chunk size handed to ``torch.split`` and slice length for
            the conditionings sampled during the D step.
            NOTE(review): the default of 0 looks like a placeholder;
            presumably callers always pass a positive value -- confirm.

    Returns:
        The ``train(x, y=None, features=None)`` closure.
    """

    def train(x, y=None, features=None):
        """Run the D step(s) and one G step; return the scalar losses.

        ``x`` (and ``y`` / ``features`` when given) are split into chunks of
        ``batch_size``; each D accumulation consumes the next chunk.

        Returns:
            Dict with ``G_loss``, ``D_loss_real`` and ``D_loss_fake`` as
            Python floats, taken from the last accumulation that ran.
            ``G_loss`` is already divided by ``num_G_accumulations``; the two
            D components are not divided.
        """
        # Start from clean gradients on both networks.
        if embedded_optimizers:
            G.optim.zero_grad()
            D.optim.zero_grad()
        else:
            GD.optimizer_D.zero_grad()
            GD.optimizer_G.zero_grad()
        # Split the incoming batch into per-accumulation chunks.
        x = torch.split(x, batch_size)
        if y is not None:
            y = torch.split(y, batch_size)
        if features is not None:
            f_ = torch.split(features, batch_size)
        else:
            f_ = None
        # Index of the next chunk to feed D. It is NOT reset between D steps,
        # so x needs at least num_D_steps * num_D_accumulations chunks.
        counter = 0

        # Optionally toggle D and G's "require_grad"
        if config["toggle_grads"]:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config["num_D_steps"]):
            # If accumulating gradients, loop multiple times before an optimizer step
            if embedded_optimizers:
                D.optim.zero_grad()
            else:
                GD.optimizer_D.zero_grad()
            for accumulation_index in range(config["num_D_accumulations"]):
                # Sample conditioning for G
                sampled_cond = sample_conditionings()
                labels_g, f_g = None, None
                # NOTE(review): when both y and features are None no branch
                # binds z_, so the z_ slice below raises UnboundLocalError.
                if features is not None and y is not None:
                    z_, labels_g, f_g = sampled_cond
                elif y is not None:
                    z_, labels_g = sampled_cond
                elif features is not None:
                    z_, f_g = sampled_cond
                # Tensors to device; in the D step the sampled conditionings
                # are truncated to batch_size.
                if labels_g is not None:
                    labels_g = (
                        labels_g[:batch_size].to(device, non_blocking=True).long()
                    )
                if f_g is not None:
                    f_g = f_g[:batch_size].to(device, non_blocking=True)
                z_ = z_[:batch_size].to(device, non_blocking=True)

                # Obtain discriminator scores
                D_fake, D_real = GD(
                    z_,
                    labels_g,
                    f_g,
                    x[counter],
                    y[counter] if y is not None else None,
                    f_[counter] if f_ is not None else None,
                    train_G=False,
                    split_D=config["split_D"],
                    policy=config["DiffAugment"],
                    DA=config["DA"],
                )

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / float(
                    config["num_D_accumulations"]
                )
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config["D_ortho"] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print("using modified ortho reg in D")
                utils.ortho(D, config["D_ortho"])

            if embedded_optimizers:
                D.optim.step()
            else:
                GD.optimizer_D.step()

        # Optionally toggle "requires_grad"
        if config["toggle_grads"]:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        if embedded_optimizers:
            G.optim.zero_grad()
        else:
            GD.optimizer_G.zero_grad()
        # Reset for the G phase; here counter is only incremented, never read.
        counter = 0

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config["num_G_accumulations"]):
            # Sample conditioning for G
            sampled_cond = sample_conditionings()
            labels_g, f_g = None, None
            if features is not None and y is not None:
                z_, labels_g, f_g = sampled_cond
            elif y is not None:
                z_, labels_g = sampled_cond
            elif features is not None:
                z_, f_g = sampled_cond
            # Tensors to device; unlike the D step, the sampled conditionings
            # are used whole (no [:batch_size] slice).
            if labels_g is not None:
                labels_g = labels_g.to(device, non_blocking=True).long()
            if f_g is not None:
                f_g = f_g.to(device, non_blocking=True)
            z_ = z_.to(device, non_blocking=True)

            # Obtain discriminator scores
            D_fake = GD(
                z_,
                labels_g,
                f_g,
                train_G=True,
                split_D=config["split_D"],
                policy=config["DiffAugment"],
                DA=config["DA"],
            )
            G_loss = losses.generator_loss(D_fake) / float(
                config["num_G_accumulations"]
            )
            G_loss.backward()
            counter += 1

        # Optionally apply modified ortho reg in G
        if config["G_ortho"] > 0.0:
            print(
                "using modified ortho reg in G"
            )  # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(
                G,
                config["G_ortho"],
                blacklist=[param for param in G.shared.parameters()],
            )
        if embedded_optimizers:
            G.optim.step()
        else:
            GD.optimizer_G.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config["ema"]:
            ema.update(state_dict["itr"])

        # NOTE(review): these names are bound only inside the loops above, so
        # a zero step/accumulation count would leave them undefined here.
        out = {
            "G_loss": float(G_loss.item()),
            "D_loss_real": float(D_loss_real.item()),
            "D_loss_fake": float(D_loss_fake.item()),
        }
        # Return G's loss and the components of D's loss.
        return out

    return train
def save_weights(
    G,
    D,
    G_ema,
    state_dict,
    config,
    experiment_name,
    embedded_optimizers=True,
    G_optim=None,
    D_optim=None,
):
    """Write the main checkpoint and, if enabled, one rotating backup copy.

    Args:
        G, D: Models forwarded to ``utils.save_weights``.
        G_ema: EMA generator; forwarded only when ``config["ema"]`` is truthy,
            otherwise ``None`` is passed in its place.
        state_dict: Training state; ``state_dict["save_num"]`` selects the
            backup slot and is advanced after a backup is written.
        config: Mapping; reads ``weights_root``, ``ema`` and
            ``num_save_copies``.
        experiment_name: Forwarded to ``utils.save_weights``.
        embedded_optimizers, G_optim, D_optim: Forwarded unchanged as keyword
            arguments to ``utils.save_weights``.
    """

    def _checkpoint(tag):
        # Each checkpoint shares models, destination and optimizer settings;
        # only the name tag differs (None for the main one).
        utils.save_weights(
            G,
            D,
            state_dict,
            config["weights_root"],
            experiment_name,
            tag,
            G_ema if config["ema"] else None,
            embedded_optimizers=embedded_optimizers,
            G_optim=G_optim,
            D_optim=D_optim,
        )

    _checkpoint(None)

    # A second copy guards against corruption if the process is killed while
    # the main checkpoint is being written.
    if not config["num_save_copies"] > 0:
        return
    _checkpoint("copy%d" % state_dict["save_num"])
    # Advance to the next backup slot, wrapping around.
    state_dict["save_num"] = (state_dict["save_num"] + 1) % config[
        "num_save_copies"
    ]
""" This function takes in the model, saves the weights (multiple copies if
requested), and optionally accumulates standing statistics for G (or its
EMA copy). No sample sheets are produced here; the fixed_z and fixed_y
arguments are accepted but currently unused. """
def save_and_sample(
    G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name
):
    """Save checkpoints and optionally refresh standing statistics.

    Args:
        G, D: Models forwarded to ``utils.save_weights``.
        G_ema: EMA generator; saved only when ``config["ema"]`` is truthy and
            used for standing stats only when ``config["use_ema"]`` is also
            truthy.
        z_, y_: Forwarded to ``utils.accumulate_standing_stats``.
        fixed_z, fixed_y: Accepted but not used in this function.
        state_dict: Training state; ``state_dict["save_num"]`` selects the
            backup slot and is advanced after a backup is written.
        config: Mapping; reads ``weights_root``, ``ema``, ``use_ema``,
            ``num_save_copies``, ``accumulate_stats``, ``n_classes`` and
            ``num_standing_accumulations``.
        experiment_name: Forwarded to ``utils.save_weights``.
    """

    def _checkpoint(tag):
        # Only the name tag varies between the main checkpoint and the copy.
        utils.save_weights(
            G,
            D,
            state_dict,
            config["weights_root"],
            experiment_name,
            tag,
            G_ema if config["ema"] else None,
        )

    _checkpoint(None)

    # A second copy guards against corruption if the process is killed while
    # the main checkpoint is being written.
    if config["num_save_copies"] > 0:
        _checkpoint("copy%d" % state_dict["save_num"])
        # Advance to the next backup slot, wrapping around.
        state_dict["save_num"] = (state_dict["save_num"] + 1) % config[
            "num_save_copies"
        ]

    if not config["accumulate_stats"]:
        return

    # Accumulate standing statistics on whichever generator is selected.
    stats_model = G_ema if config["ema"] and config["use_ema"] else G
    utils.accumulate_standing_stats(
        stats_model,
        z_,
        y_,
        config["n_classes"],
        config["num_standing_accumulations"],
    )
""" This function runs the inception metrics code, checks if the results
are an improvement over the previous best (either in IS or FID,
user-specified), logs the results, and saves a best_ copy if it's an
improvement. """
def test(
    G,
    D,
    G_ema,
    z_,
    y_,
    state_dict,
    config,
    sample,
    get_inception_metrics,
    experiment_name,
    test_log,
    loader=None,
    embedded_optimizers=True,
    G_optim=None,
    D_optim=None,
    rank=0,
):
    """Compute inception metrics and, on rank 0, track bests and log them.

    Args:
        G, D, G_ema: Models; saved through ``utils.save_weights`` when the
            selected metric improves (``G_ema`` only when ``config["ema"]``).
        z_, y_: Forwarded to ``utils.accumulate_standing_stats`` when
            ``config["accumulate_stats"]`` is truthy.
        state_dict: Training state; reads ``itr`` and reads/writes
            ``best_IS``, ``best_FID`` and ``save_best_num``.
        config: Mapping; reads ``accumulate_stats``, ``ema``, ``use_ema``,
            ``n_classes``, ``num_standing_accumulations``,
            ``num_inception_images``, ``which_best``, ``weights_root`` and
            ``num_best_copies``.
        sample: Forwarded to ``get_inception_metrics``.
        get_inception_metrics: Callable expected to return four values when
            called without ``loader_ref`` and five values when called with it.
        experiment_name: Forwarded to ``utils.save_weights``.
        test_log: Object whose ``log(**metrics)`` records the results.
        loader: Optional reference loader passed as ``loader_ref``.
        embedded_optimizers, G_optim, D_optim: Forwarded to
            ``utils.save_weights``.
        rank: Only rank 0 saves checkpoints, updates bests and logs.

    Returns:
        Tuple ``(IS_mean, FID)``.
    """
    print("Gathering inception metrics...")

    if config["accumulate_stats"]:
        stats_model = G_ema if config["ema"] and config["use_ema"] else G
        utils.accumulate_standing_stats(
            stats_model,
            z_,
            y_,
            config["n_classes"],
            config["num_standing_accumulations"],
        )

    # The metric function yields one extra value when given a reference
    # loader; the trailing values are not used here.
    if loader is None:
        IS_mean, IS_std, FID, _ = get_inception_metrics(
            sample, config["num_inception_images"], num_splits=10
        )
    else:
        IS_mean, IS_std, FID, _, _ = get_inception_metrics(
            sample, config["num_inception_images"], num_splits=10, loader_ref=loader
        )

    summary = (state_dict["itr"], IS_mean, IS_std, FID)
    print(
        "Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f"
        % summary
    )

    # Everything below is bookkeeping done by rank 0 only.
    if rank != 0:
        return IS_mean, FID

    # If improved over the previous best metric, save the appropriate copy.
    criterion = config["which_best"]
    if criterion == "IS":
        improved = IS_mean > state_dict["best_IS"]
    elif criterion == "FID":
        improved = FID < state_dict["best_FID"]
    else:
        improved = False

    if improved:
        print("%s improved over previous best, saving checkpoint..." % criterion)
        utils.save_weights(
            G,
            D,
            state_dict,
            config["weights_root"],
            experiment_name,
            "best%d" % state_dict["save_best_num"],
            G_ema if config["ema"] else None,
            embedded_optimizers=embedded_optimizers,
            G_optim=G_optim,
            D_optim=D_optim,
        )
        # Advance to the next "best" slot, wrapping around.
        state_dict["save_best_num"] = (state_dict["save_best_num"] + 1) % config[
            "num_best_copies"
        ]

    # Both running bests are refreshed whichever metric drives checkpointing.
    state_dict["best_IS"] = max(state_dict["best_IS"], IS_mean)
    state_dict["best_FID"] = min(state_dict["best_FID"], FID)

    # Log results to file
    test_log.log(
        itr=int(state_dict["itr"]),
        IS_mean=float(IS_mean),
        IS_std=float(IS_std),
        FID=float(FID),
    )
    return IS_mean, FID