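# Training loop for a conditional GAN (class labels and/or instance features):
# constructs G, D, and an EMA copy of G, optionally resumes from a checkpoint,
# alternates generator/discriminator phases with optional adaptive augmentation
# (ADA), and periodically saves snapshots, evaluates metrics, and early-stops
# on FID.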
import os
import time
import copy
import json
import pickle
import shutil

import psutil
import PIL.Image
import numpy as np

import torch
import torch.distributed as dist
from torchvision import transforms

import dnnlib
import legacy
from torch_utils import misc
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix
from metrics import metric_main


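# Pick a deterministic grid of dataset indices (grouped per label when the
# dataset has labels) and return the grid dimensions plus the stacked images
# and labels, e.g. for snapshot previews.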
def setup_snapshot_image_grid(training_set, random_seed=0):
    rnd = np.random.RandomState(random_seed)
    gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
    gh = np.clip(4320 // training_set.image_shape[1], 4, 32)

    if not training_set.has_labels:
        all_indices = list(range(len(training_set)))
        rnd.shuffle(all_indices)
        grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]
    else:
        label_groups = dict()
        for idx in range(len(training_set)):
            label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
            if label not in label_groups:
                label_groups[label] = []
            label_groups[label].append(idx)

        label_order = sorted(label_groups.keys())
        for label in label_order:
            rnd.shuffle(label_groups[label])

        grid_indices = []
        for y in range(gh):
            label = label_order[y % len(label_order)]
            indices = label_groups[label]
            grid_indices += [indices[x % len(indices)] for x in range(gw)]
            label_groups[label] = [
                indices[(i + gw) % len(indices)] for i in range(len(indices))
            ]

    images, labels = zip(*[training_set[i] for i in grid_indices])
    return (gw, gh), np.stack(images), np.stack(labels)


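# Tile a batch of images [N, C, H, W] into a single (gh x gw) image, rescale
# from `drange` to [0, 255], and save it as a grayscale or RGB file.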
def save_image_grid(img, fname, drange, grid_size):
    lo, hi = drange
    img = np.asarray(img, dtype=np.float32)
    img = (img - lo) * (255 / (hi - lo))
    img = np.rint(img).clip(0, 255).astype(np.uint8)

    gw, gh = grid_size
    _N, C, H, W = img.shape
    img = img.reshape(gh, gw, C, H, W)
    img = img.transpose(0, 3, 1, 4, 2)
    img = img.reshape(gh * H, gw * W, C)

    assert C in [1, 3]
    if C == 1:
        PIL.Image.fromarray(img[:, :, 0], "L").save(fname)
    if C == 3:
        PIL.Image.fromarray(img, "RGB").save(fname)


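# Main training loop. `class_cond` enables conditioning on per-class labels `c`
# and `instance_cond` enables conditioning on per-instance feature vectors `h`;
# both are passed to G and D (as zero-width tensors when disabled).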
def training_loop(
    exp_name="default_name",        # Experiment name.
    run_dir=".",                    # Output directory.
    temp_dir=".",                   # Node-local temporary directory, used when slurm=True.
    training_set_kwargs={},         # Options for training set.
    data_loader_kwargs={},          # Options for torch.utils.data.DataLoader.
    G_kwargs={},                    # Options for generator network.
    D_kwargs={},                    # Options for discriminator network.
    G_opt_kwargs={},                # Options for generator optimizer.
    D_opt_kwargs={},                # Options for discriminator optimizer.
    augment_kwargs=None,            # Options for augmentation pipeline. None = disable.
    loss_kwargs={},                 # Options for loss function.
    class_cond=False,               # Condition on class labels?
    instance_cond=False,            # Condition on instance feature vectors?
    metrics=[],                     # Metrics to evaluate during training.
    random_seed=0,                  # Global random seed.
    num_gpus=1,                     # Number of GPUs participating in the training.
    slurm=False,                    # Copy the dataset to node-local storage before training?
    rank=0,                         # Global rank of the current process in [0, num_gpus).
    local_rank=0,                   # Rank of the current process on its node (selects the CUDA device).
    batch_size=4,                   # Total batch size for one training iteration.
    batch_gpu=4,                    # Number of samples processed at a time by one GPU.
    ema_kimg=10,                    # Half-life of the exponential moving average (EMA) of generator weights, in kimg.
    ema_rampup=None,                # EMA ramp-up coefficient. None = no ramp-up.
    G_reg_interval=4,               # How often to perform regularization for G, in minibatches.
    D_reg_interval=16,              # How often to perform regularization for D, in minibatches.
    augment_p=0,                    # Initial value of augmentation probability.
    ada_target=None,                # ADA target value. None = fixed p.
    ada_interval=4,                 # How often to perform ADA adjustment, in minibatches.
    ada_kimg=500,                   # ADA adjustment speed, in how many kimg it takes for p to change by one unit.
    total_kimg=25000,               # Total length of the training, in thousands of real images.
    kimg_per_tick=4,                # Progress snapshot interval, in kimg.
    image_snapshot_ticks=50,        # How often to save image snapshots (appears unused in this loop).
    network_snapshot_ticks=50,      # How often to save network snapshots, in ticks.
    es_patience=100000000,          # Early-stopping patience: stop if FID has not improved for this many real images.
    resume_pkl=None,                # Network pickle to resume training from.
    cudnn_benchmark=True,           # Enable torch.backends.cudnn.benchmark?
    allow_tf32=False,               # Allow PyTorch to use TF32 internally?
    abort_fn=None,                  # Callback function for determining whether to abort training.
    progress_fn=None,               # Callback function for updating training progress.
):
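    # Initialize CUDA device, random seeds, and cuDNN/TF32 behaviour.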
    start_time = time.time()

    device = "cuda:{}".format(local_rank)
    torch.cuda.set_device(device)

    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    torch.backends.cudnn.allow_tf32 = allow_tf32
    conv2d_gradfix.enabled = True
    grid_sample_gradfix.enabled = True

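    # When running under SLURM, copy the dataset (and instance features, if used)
    # to node-local temporary storage once per node.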
    if slurm:
        img_filename = os.path.basename(training_set_kwargs.root)
        tmp_file_img = os.path.join(temp_dir, img_filename)
        if instance_cond:
            feats_filename = os.path.basename(training_set_kwargs.root_feats)
            tmp_file_feats = os.path.join(temp_dir, feats_filename)
        if local_rank == 0:
            print("start copying data locally")
            if not os.path.exists(tmp_file_img):
                shutil.copy2(training_set_kwargs.root, tmp_file_img)
            if instance_cond and not os.path.exists(tmp_file_feats):
                shutil.copy2(training_set_kwargs.root_feats, tmp_file_feats)
            print("finished copying data locally")
        dist.barrier()
        training_set_kwargs.root = tmp_file_img
        if instance_cond:
            training_set_kwargs.root_feats = tmp_file_feats
        print("Final path dataset ", training_set_kwargs.root)
        if instance_cond:
            print("Final path dataset (feats)", training_set_kwargs.root_feats)

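    # Load training set.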
    if rank == 0:
        print("Loading training set...")
    if training_set_kwargs.xflip:
        transform = transforms.RandomHorizontalFlip()
    else:
        transform = None
    training_set = dnnlib.util.construct_class_by_name(
        **{**training_set_kwargs, "transform": transform}
    )
    training_set_sampler = misc.InfiniteSampler(
        dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed
    )
    training_set_iterator = iter(
        torch.utils.data.DataLoader(
            dataset=training_set,
            sampler=training_set_sampler,
            batch_size=batch_size // num_gpus,
            **data_loader_kwargs,
        )
    )
    if rank == 0:
        print()
        print("Num images: ", len(training_set))
        print("Image shape:", training_set.resolution)
        print("Label shape:", training_set.label_dim)
        print("Features shape:", training_set.feature_dim)
        print()

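    # Construct networks.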
    if rank == 0:
        print("Constructing networks...")
    common_kwargs = dict(
        c_dim=training_set.label_dim if class_cond else 0,
        h_dim=training_set.feature_dim if instance_cond else 0,
        img_resolution=training_set.resolution,
        img_channels=3,
    )
    G = (
        dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs)
        .train()
        .requires_grad_(False)
        .to(device)
    )
    D = (
        dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs)
        .train()
        .requires_grad_(False)
        .to(device)
    )
    G_ema = copy.deepcopy(G).eval()

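    # Resume from an existing checkpoint: an explicit `resume_pkl` on rank 0,
    # otherwise the rolling "last_net.pkl" in run_dir if it exists.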
    snapshot_pkl_last = os.path.join(run_dir, "last_net")

    if num_gpus > 1:
        dist.barrier()
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [("G", G), ("D", D), ("G_ema", G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
        print("Successfully loaded G,D,G_ema from specific pkl checkpoint")
    else:
        try:
            print(f'Resuming from "{snapshot_pkl_last}.pkl"')
            with dnnlib.util.open_url(snapshot_pkl_last + ".pkl") as f:
                resume_data = legacy.load_network_pkl(f)
            for name, module in [("G", G), ("D", D), ("G_ema", G_ema)]:
                misc.copy_params_and_buffers(
                    resume_data[name], module, require_all=False
                )
            print("Successfully loaded G,D,G_ema from last checkpoint")
        except Exception:
            print("Starting training from scratch")

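    # Print network summary tables (rank 0 only).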
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        h = torch.empty([batch_gpu, G.h_dim], device=device)
        img = misc.print_module_summary(G, [z, c, h])
        misc.print_module_summary(D, [img, c, h])

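    # Set up the optional augmentation pipeline and ADA statistics collector.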
    if rank == 0:
        print("Setting up augmentation...")
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = (
            dnnlib.util.construct_class_by_name(**augment_kwargs)
            .train()
            .requires_grad_(False)
            .to(device)
        )
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex="Loss/signs/real")

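    # Distribute trainable modules across GPUs with DistributedDataParallel.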
    if rank == 0:
        print(f"Distributing across {num_gpus} GPUs...")
    ddp_modules = dict()
    for name, module in [
        ("G_mapping", G.mapping),
        ("G_synthesis", G.synthesis),
        ("D", D),
        (None, G_ema),
        ("augment_pipe", augment_pipe),
    ]:
        if (
            (num_gpus > 1)
            and (module is not None)
            and len(list(module.parameters())) != 0
        ):
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(
                module, device_ids=[device], broadcast_buffers=False
            )
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

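    # Set up training phases: one "main" and one lazy-regularization "reg" phase
    # per network, or a single combined phase when the regularization interval is None.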
    if rank == 0:
        print("Setting up training phases...")
    loss = dnnlib.util.construct_class_by_name(
        device=device, **ddp_modules, **loss_kwargs
    )
    phases = []
    for name, module, opt_kwargs, reg_interval in [
        ("G", G, G_opt_kwargs, G_reg_interval),
        ("D", D, D_opt_kwargs, D_reg_interval),
    ]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(
                params=module.parameters(), **opt_kwargs
            )
            phases += [
                dnnlib.EasyDict(name=name + "both", module=module, opt=opt, interval=1)
            ]
        else:
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(
                module.parameters(), **opt_kwargs
            )
            phases += [
                dnnlib.EasyDict(name=name + "main", module=module, opt=opt, interval=1)
            ]
            phases += [
                dnnlib.EasyDict(
                    name=name + "reg", module=module, opt=opt, interval=reg_interval
                )
            ]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

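    # Try to restore optimizer state saved alongside the rolling "last_net" checkpoint.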
    if num_gpus > 1:
        dist.barrier()
    print("Resuming optimizers ")
    try:
        for phase in phases:
            phase["opt"].load_state_dict(
                torch.load(
                    snapshot_pkl_last + phase["name"] + "_opt.pth", map_location=device
                )
            )
        print("All optimizers loaded from checkpoint! ")
    except Exception:
        print("Could not load checkpoint! ", snapshot_pkl_last)
        print("Starting training from scratch")

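    # Placeholders retained from the image-grid snapshot export, which is not
    # performed in this loop.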
    grid_size = None
    grid_z = None
    grid_c = None

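    # Initialize logs.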
    if rank == 0:
        print("Initializing logs...")
    stats_collector = training_stats.Collector(regex=".*")
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, "stats.jsonl"), "wt")
        try:
            import torch.utils.tensorboard as tensorboard

            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print("Skipping tfevents export:", err)

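    # Restore training progress (current tick / image count and best FID so far)
    # and initialize per-tick bookkeeping.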
    if rank == 0:
        print(f"Training for {total_kimg} kimg...")
        print()
    cur_nimg = 0
    cur_tick = 0
    try:
        (cur_tick, cur_nimg) = np.load(
            snapshot_pkl_last + "last_itr.npy", allow_pickle=True
        )
        print("Loading last tick and nimg ", cur_tick, cur_nimg)
    except Exception:
        print("No last iter to load, starting from scratch")

    best_fid = 1000000
    best_fid_nimg = 0
    try:
        (_, best_fid_nimg, best_fid) = np.load(
            os.path.join(run_dir, "best_fid_itr.npy"), allow_pickle=True
        )
        print(f"Loading best fid itr {best_fid_nimg}, value of fid is {best_fid}")
    except Exception:
        print("No last iter to load for best fid, starting from scratch")

    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(cur_nimg, total_kimg)

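    # Main loop: each iteration consumes one global batch and runs every phase
    # that is due at the current batch index.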
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function("data_fetch"):
            batch = next(training_set_iterator)
            if instance_cond and class_cond:
                phase_real_img, phase_real_c, phase_real_h, _ = batch
            elif instance_cond:
                phase_real_img, phase_real_h, _ = batch
                phase_real_c = torch.empty([batch_gpu, G.c_dim], device=device)
            elif class_cond:
                phase_real_img, phase_real_c = batch
                phase_real_h = torch.empty([batch_gpu, G.h_dim], device=device)
            else:
                phase_real_img = batch
                phase_real_c = torch.empty([batch_gpu, G.c_dim], device=device)
                phase_real_h = torch.empty([batch_gpu, G.h_dim], device=device)

            phase_real_img = (
                phase_real_img.to(device).to(torch.float32) / 127.5 - 1
            ).split(batch_gpu)
            all_gen_c = [
                training_set.get_label(np.random.randint(len(training_set)))
                for _ in range(len(phases) * batch_size)
            ]
            all_gen_h = [
                training_set.get_instance_features(np.random.randint(len(training_set)))
                for _ in range(len(phases) * batch_size)
            ]

            if class_cond:
                phase_real_c = phase_real_c.to(device).split(batch_gpu)
            if instance_cond:
                phase_real_h = phase_real_h.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [
                phase_gen_z.split(batch_gpu)
                for phase_gen_z in all_gen_z.split(batch_size)
            ]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [
                phase_gen_c.split(batch_gpu)
                for phase_gen_c in all_gen_c.split(batch_size)
            ]
            all_gen_h = torch.from_numpy(np.stack(all_gen_h)).pin_memory().to(device)
            all_gen_h = [
                phase_gen_h.split(batch_gpu)
                for phase_gen_h in all_gen_h.split(batch_size)
            ]

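        # Execute training phases.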
        for phase, phase_gen_z, phase_gen_c, phase_gen_h in zip(
            phases, all_gen_z, all_gen_c, all_gen_h
        ):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over batch_gpu-sized rounds.
            for round_idx, (real_img, real_c, real_h, gen_z, gen_c, gen_h) in enumerate(
                zip(
                    phase_real_img,
                    phase_real_c,
                    phase_real_h,
                    phase_gen_z,
                    phase_gen_c,
                    phase_gen_h,
                )
            ):
                # Sync gradients across GPUs only on the last accumulation round.
                sync = round_idx == batch_size // (batch_gpu * num_gpus) - 1
                gain = phase.interval
                loss.accumulate_gradients(
                    phase=phase.name,
                    real_img=real_img,
                    real_c=real_c,
                    real_h=real_h,
                    gen_z=gen_z,
                    gen_c=gen_c,
                    gen_h=gen_h,
                    sync=sync,
                    gain=gain,
                )

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + "_opt"):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(
                            param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad
                        )
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

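        # Update G_ema (exponential moving average of the generator weights).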
        with torch.autograd.profiler.record_function("Gema"):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        cur_nimg += batch_size
        batch_idx += 1

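        # Execute ADA heuristic: nudge the augmentation probability toward ada_target.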
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = (
                np.sign(ada_stats["Loss/signs/real"] - ada_target)
                * (batch_size * ada_interval)
                / (ada_kimg * 1000)
            )
            augment_pipe.p.copy_(
                (augment_pipe.p + adjust).max(misc.constant(0, device=device))
            )

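        # Perform maintenance tasks once per tick (every kimg_per_tick thousand images, or when done).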
        done = cur_nimg >= total_kimg * 1000
        if (
            (not done)
            and (cur_tick != 0)
            and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000)
        ):
            continue

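        # Print status line, accumulating the same information into training_stats.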
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [
            f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"
        ]
        fields += [
            f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"
        ]
        fields += [
            f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"
        ]
        fields += [
            f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"
        ]
        fields += [
            f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"
        ]
        fields += [
            f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
        ]
        fields += [
            f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"
        ]
        torch.cuda.reset_peak_memory_stats()
        fields += [
            f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"
        ]
        training_stats.report0(
            "Timing/total_hours", (tick_end_time - start_time) / (60 * 60)
        )
        training_stats.report0(
            "Timing/total_days", (tick_end_time - start_time) / (24 * 60 * 60)
        )
        if rank == 0:
            print(" ".join(fields))

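        # Check for abort.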
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print("Aborting...")

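        # Save network snapshot: modules are deep-copied to CPU and written to the
        # rolling "last_net" pickle, together with optimizer state and iteration counters.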
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (
            done or cur_tick % network_snapshot_ticks == 0
        ):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            snapshot_pkl = os.path.join(
                run_dir, f"network-snapshot-{cur_nimg//1000:06d}.pkl"
            )
            for name, module in [
                ("G", G),
                ("D", D),
                ("G_ema", G_ema),
                ("augment_pipe", augment_pipe),
            ]:
                if module is not None:
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module

            with open(snapshot_pkl_last + ".pkl", "wb") as f:
                pickle.dump(snapshot_data, f)
            for phase in phases:
                torch.save(
                    phase["opt"].state_dict(),
                    snapshot_pkl_last + phase["name"] + "_opt.pth",
                )
            np.save(snapshot_pkl_last + "last_itr", (cur_tick, cur_nimg))

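        # Evaluate metrics; track the best FID, save a "best" snapshot, and
        # early-stop if FID has not improved within es_patience real images.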
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print("Evaluating metrics...")
            for metric in metrics:
                result_dict = metric_main.calc_metric(
                    metric=metric,
                    G=snapshot_data["G_ema"],
                    dataset_kwargs=training_set_kwargs,
                    num_gpus=num_gpus,
                    rank=rank,
                    device=device,
                )
                if rank == 0:
                    metric_main.report_metric(
                        result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl
                    )
                stats_metrics.update(result_dict.results)

                if metric == "fid50k_full" and rank == 0:
                    cur_fid = result_dict["results"]["fid50k_full"]
                    if cur_fid < best_fid:
                        print("Saving network snapshot with best FID")
                        best_fid = cur_fid
                        best_fid_nimg = cur_nimg
                        snapshot_best_pkl = os.path.join(
                            run_dir, "best-network-snapshot.pkl"
                        )
                        with open(snapshot_best_pkl, "wb") as f:
                            pickle.dump(snapshot_data, f)
                        np.save(
                            os.path.join(run_dir, "best_fid_itr"),
                            (cur_tick, cur_nimg, best_fid),
                        )
                    else:
                        if (cur_nimg - best_fid_nimg) > es_patience:
                            done = True
                            print("Stopping training due to early stopping.")
        del snapshot_data  # Conserve memory.

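        # Collect per-phase timing and global statistics.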
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0("Timing/" + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

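        # Update logs (stats.jsonl, TensorBoard) and the progress callback.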
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + "\n")
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(
                    name, value.mean, global_step=global_step, walltime=walltime
                )
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(
                    f"Metrics/{name}", value, global_step=global_step, walltime=walltime
                )
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

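        # Advance to the next tick.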
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    if rank == 0:
        print()
        print("Exiting...")