""" train_fns.py |
|
Functions for the main loop of training different conditional image models |
|
""" |
|
import torch |
|
|
|
import utils |
|
import losses |
|
|
|
|
|
|
|
def dummy_training_function(): |
|
def train(x, y): |
|
return {} |
|
|
|
return train |
|
|
|
|
|
def GAN_training_function(
    G,
    D,
    GD,
    ema,
    state_dict,
    config,
    sample_conditionings,
    embedded_optimizers=True,
    device="cuda",
    batch_size=0,
):
    """Build the train() closure that runs one full D/G update."""

    def train(x, y=None, features=None):
        if embedded_optimizers:
            G.optim.zero_grad()
            D.optim.zero_grad()
        else:
            GD.optimizer_D.zero_grad()
            GD.optimizer_G.zero_grad()

        # Split the real data into chunks of size batch_size, one chunk per
        # gradient accumulation step.
        x = torch.split(x, batch_size)
        if y is not None:
            y = torch.split(y, batch_size)
        if features is not None:
            f_ = torch.split(features, batch_size)
        else:
            f_ = None
        counter = 0

        # Optionally toggle D's gradients on and G's off while updating D.
        if config["toggle_grads"]:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config["num_D_steps"]):
            if embedded_optimizers:
                D.optim.zero_grad()
            else:
                GD.optimizer_D.zero_grad()
            for accumulation_index in range(config["num_D_accumulations"]):
                # Sample noise plus whichever conditionings (class labels,
                # instance features) this model variant uses.
                sampled_cond = sample_conditionings()
                labels_g, f_g = None, None
                if features is not None and y is not None:
                    z_, labels_g, f_g = sampled_cond
                elif y is not None:
                    z_, labels_g = sampled_cond
                elif features is not None:
                    z_, f_g = sampled_cond
                else:
                    # Unconditional case; assumes the sampler returns noise only.
                    z_ = sampled_cond

                if labels_g is not None:
                    labels_g = (
                        labels_g[:batch_size].to(device, non_blocking=True).long()
                    )
                if f_g is not None:
                    f_g = f_g[:batch_size].to(device, non_blocking=True)
                z_ = z_[:batch_size].to(device, non_blocking=True)

                D_fake, D_real = GD(
                    z_,
                    labels_g,
                    f_g,
                    x[counter],
                    y[counter] if y is not None else None,
                    f_[counter] if f_ is not None else None,
                    train_G=False,
                    split_D=config["split_D"],
                    policy=config["DiffAugment"],
                    DA=config["DA"],
                )

                # Divide by the number of accumulations so the summed
                # gradients match a single large-batch step.
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real
                )
                D_loss = (D_loss_real + D_loss_fake) / float(
                    config["num_D_accumulations"]
                )
                D_loss.backward()
                counter += 1

            # Optionally apply modified orthogonal regularization in D.
            if config["D_ortho"] > 0.0:
                print("using modified ortho reg in D")
                utils.ortho(D, config["D_ortho"])

            if embedded_optimizers:
                D.optim.step()
            else:
                GD.optimizer_D.step()

        # Optionally toggle G's gradients on and D's off while updating G.
        if config["toggle_grads"]:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        if embedded_optimizers:
            G.optim.zero_grad()
        else:
            GD.optimizer_G.zero_grad()

        counter = 0

        for accumulation_index in range(config["num_G_accumulations"]):
            sampled_cond = sample_conditionings()
            labels_g, f_g = None, None
            if features is not None and y is not None:
                z_, labels_g, f_g = sampled_cond
            elif y is not None:
                z_, labels_g = sampled_cond
            elif features is not None:
                z_, f_g = sampled_cond
            else:
                # Unconditional case; assumes the sampler returns noise only.
                z_ = sampled_cond

            if labels_g is not None:
                labels_g = labels_g.to(device, non_blocking=True).long()
            if f_g is not None:
                f_g = f_g.to(device, non_blocking=True)
            z_ = z_.to(device, non_blocking=True)

            D_fake = GD(
                z_,
                labels_g,
                f_g,
                train_G=True,
                split_D=config["split_D"],
                policy=config["DiffAugment"],
                DA=config["DA"],
            )
            G_loss = losses.generator_loss(D_fake) / float(
                config["num_G_accumulations"]
            )
            G_loss.backward()
            counter += 1

        # Optionally apply modified orthogonal regularization in G, exempting
        # the shared embedding parameters.
        if config["G_ortho"] > 0.0:
            print("using modified ortho reg in G")
            utils.ortho(
                G,
                config["G_ortho"],
                blacklist=[param for param in G.shared.parameters()],
            )
        if embedded_optimizers:
            G.optim.step()
        else:
            GD.optimizer_G.step()

        # Update the exponential moving average copy of G, if enabled.
        if config["ema"]:
            ema.update(state_dict["itr"])

        out = {
            "G_loss": float(G_loss.item()),
            "D_loss_real": float(D_loss_real.item()),
            "D_loss_fake": float(D_loss_fake.item()),
        }
        return out

    return train
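# Example wiring of the closure into an outer loop (a minimal sketch;
# `loader` and the construction of `sample_conditionings` are assumptions,
# not part of this module):
#
#     train = GAN_training_function(
#         G, D, GD, ema, state_dict, config,
#         sample_conditionings, batch_size=config["batch_size"],
#     )
#     for x, y in loader:
#         state_dict["itr"] += 1
#         metrics = train(x.to("cuda"), y.to("cuda"))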
|
|
|
|
|
def save_weights(
    G,
    D,
    G_ema,
    state_dict,
    config,
    experiment_name,
    embedded_optimizers=True,
    G_optim=None,
    D_optim=None,
):
    """Save the current weights, plus a rotating numbered copy if requested."""
    utils.save_weights(
        G,
        D,
        state_dict,
        config["weights_root"],
        experiment_name,
        None,
        G_ema if config["ema"] else None,
        embedded_optimizers=embedded_optimizers,
        G_optim=G_optim,
        D_optim=D_optim,
    )

    if config["num_save_copies"] > 0:
        utils.save_weights(
            G,
            D,
            state_dict,
            config["weights_root"],
            experiment_name,
            "copy%d" % state_dict["save_num"],
            G_ema if config["ema"] else None,
            embedded_optimizers=embedded_optimizers,
            G_optim=G_optim,
            D_optim=D_optim,
        )
        state_dict["save_num"] = (state_dict["save_num"] + 1) % config[
            "num_save_copies"
        ]
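# The numbered copies behave as a ring buffer: with num_save_copies == 2,
# successive calls write copy0, copy1, copy0, copy1, ..., so the two most
# recent snapshots are always retained.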
|
|
|
|
|
""" This function takes in the model, saves the weights (multiple copies if |
|
requested), and prepares sample sheets: one consisting of samples given |
|
a fixed noise seed (to show how the model evolves throughout training), |
|
a set of full conditional sample sheets, and a set of interp sheets. """ |
|
|
|
|
|
def save_and_sample( |
|
G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name |
|
): |
|
utils.save_weights( |
|
G, |
|
D, |
|
state_dict, |
|
config["weights_root"], |
|
experiment_name, |
|
None, |
|
G_ema if config["ema"] else None, |
|
) |
|
|
|
|
|
if config["num_save_copies"] > 0: |
|
utils.save_weights( |
|
G, |
|
D, |
|
state_dict, |
|
config["weights_root"], |
|
experiment_name, |
|
"copy%d" % state_dict["save_num"], |
|
G_ema if config["ema"] else None, |
|
) |
|
state_dict["save_num"] = (state_dict["save_num"] + 1) % config[ |
|
"num_save_copies" |
|
] |
|
|
|
|
|
if config["accumulate_stats"]: |
|
utils.accumulate_standing_stats( |
|
G_ema if config["ema"] and config["use_ema"] else G, |
|
z_, |
|
y_, |
|
config["n_classes"], |
|
config["num_standing_accumulations"], |
|
) |
|
|
|
|
|
""" This function runs the inception metrics code, checks if the results |
|
are an improvement over the previous best (either in IS or FID, |
|
user-specified), logs the results, and saves a best_ copy if it's an |
|
improvement. """ |
|
|
|
|
|
def test( |
|
G, |
|
D, |
|
G_ema, |
|
z_, |
|
y_, |
|
state_dict, |
|
config, |
|
sample, |
|
get_inception_metrics, |
|
experiment_name, |
|
test_log, |
|
loader=None, |
|
embedded_optimizers=True, |
|
G_optim=None, |
|
D_optim=None, |
|
rank=0, |
|
): |
|
print("Gathering inception metrics...") |
|
if config["accumulate_stats"]: |
|
utils.accumulate_standing_stats( |
|
G_ema if config["ema"] and config["use_ema"] else G, |
|
z_, |
|
y_, |
|
config["n_classes"], |
|
config["num_standing_accumulations"], |
|
) |
|
if loader is not None: |
|
IS_mean, IS_std, FID, stratified_FID, prdc_metrics = get_inception_metrics( |
|
sample, config["num_inception_images"], num_splits=10, loader_ref=loader |
|
) |
|
else: |
|
IS_mean, IS_std, FID, stratified_FID = get_inception_metrics( |
|
sample, config["num_inception_images"], num_splits=10 |
|
) |
|
print( |
|
"Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f" |
|
% (state_dict["itr"], IS_mean, IS_std, FID) |
|
) |
|
|
|
if rank == 0: |
|
if (config["which_best"] == "IS" and IS_mean > state_dict["best_IS"]) or ( |
|
config["which_best"] == "FID" and FID < state_dict["best_FID"] |
|
): |
|
print( |
|
"%s improved over previous best, saving checkpoint..." |
|
% config["which_best"] |
|
) |
|
utils.save_weights( |
|
G, |
|
D, |
|
state_dict, |
|
config["weights_root"], |
|
experiment_name, |
|
"best%d" % state_dict["save_best_num"], |
|
G_ema if config["ema"] else None, |
|
embedded_optimizers=embedded_optimizers, |
|
G_optim=G_optim, |
|
D_optim=D_optim, |
|
) |
|
state_dict["save_best_num"] = (state_dict["save_best_num"] + 1) % config[ |
|
"num_best_copies" |
|
] |
|
state_dict["best_IS"] = max(state_dict["best_IS"], IS_mean) |
|
state_dict["best_FID"] = min(state_dict["best_FID"], FID) |
|
|
|
test_log.log( |
|
itr=int(state_dict["itr"]), |
|
IS_mean=float(IS_mean), |
|
IS_std=float(IS_std), |
|
FID=float(FID), |
|
) |
|
return IS_mean, FID |
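# Example call site (a minimal sketch; how `sample` and
# `get_inception_metrics` are constructed is outside this module, and the
# names below are assumptions):
#
#     IS_mean, FID = test(
#         G, D, G_ema, z_, y_, state_dict, config, sample,
#         get_inception_metrics, experiment_name, test_log,
#     )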
|
|