# dp2/gan_trainer.py
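"""GAN trainer for DeepPrivacy2 (dp2): alternating generator/discriminator updates
with an EMA generator, AMP gradient scaling, manual DDP gradient/buffer
synchronization, periodic evaluation and image logging."""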
import atexit
import logging
import math
import time
import typing
from collections import defaultdict

import torch
import tops
from easydict import EasyDict
from tops import logger, checkpointer

from dp2 import utils
from dp2.utils import vis_utils


def accumulate_gradients(params, fp16_ddp_accumulate):
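    """Averages gradients of `params` across DDP workers with a single all_reduce.

    All gradients are flattened into one tensor (optionally cast to fp16) to keep
    communication overhead low, then written back to each parameter's .grad.
    """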
    if len(params) == 0:
        return
    params = [param for param in params if param.grad is not None]
    flat = torch.cat([param.grad.flatten() for param in params])
    orig_dtype = flat.dtype
    if tops.world_size() > 1:
        if fp16_ddp_accumulate:
            flat = flat.half() / tops.world_size()
        else:
            flat /= tops.world_size()
        torch.distributed.all_reduce(flat)
        flat = flat.to(orig_dtype)
    grads = flat.split([param.numel() for param in params])
    for param, grad in zip(params, grads):
        param.grad = grad.reshape(param.shape)


def accumulate_buffers(module: torch.nn.Module):
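    """Averages all buffers of `module` (e.g. batch-norm statistics) across DDP workers."""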
    buffers = [buf for buf in module.buffers()]
    if len(buffers) == 0:
        return
    flat = torch.cat([buf.flatten() for buf in buffers])
    if tops.world_size() > 1:
        torch.distributed.all_reduce(flat)
        flat /= tops.world_size()
    bufs = flat.split([buf.numel() for buf in buffers])
    for old, new in zip(buffers, bufs):
        old.copy_(new.reshape(old.shape), non_blocking=True)


def check_ddp_consistency(module):
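    """Asserts that all parameters and buffers of `module` match across DDP workers."""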
    if tops.world_size() == 1:
        return
    assert isinstance(module, torch.nn.Module)
    params_buffs = list(module.named_parameters()) + list(module.named_buffers())
    for name, tensor in params_buffs:
        fullname = type(module).__name__ + '.' + name
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = torch.nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname


class AverageMeter():
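    """Accumulates per-key running means of logged tensors between logging steps."""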

    def __init__(self) -> None:
        self.to_log = dict()
        self.n = defaultdict(int)

    @torch.no_grad()
    def update(self, values: dict):
        for key, value in values.items():
            self.n[key] += 1
            if key in self.to_log:
                self.to_log[key] += value.mean().detach()
            else:
                self.to_log[key] = value.mean().detach()

    def get_average(self):
        return {key: value / self.n[key] for key, value in self.to_log.items()}


class GANTrainer:
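    """Alternating GAN training loop with an EMA generator, AMP grad scaling,
    manual DDP gradient/buffer accumulation, periodic evaluation and image logging.
    """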
    def __init__(
            self,
            G: torch.nn.Module,
            D: torch.nn.Module,
            G_EMA: torch.nn.Module,
            D_optim: torch.optim.Optimizer,
            G_optim: torch.optim.Optimizer,
            dl_train: typing.Iterator,
            dl_val: typing.Iterable,
            scaler_D: torch.cuda.amp.GradScaler,
            scaler_G: torch.cuda.amp.GradScaler,
            ims_per_log: int,
            max_images_to_train: int,
            loss_handler,
            ims_per_val: int,
            evaluate_fn,
            batch_size: int,
            broadcast_buffers: bool,
            fp16_ddp_accumulate: bool,
            save_state: bool,
            *args, **kwargs):
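        """Registers models, optimizers, grad scalers and train state for checkpointing.

        All counters (`ims_per_log`, `ims_per_val`, `max_images_to_train`) are in
        number of images, matching `logger.global_step()`.
        """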
        super().__init__(*args, **kwargs)
        self.G = G
        self.D = D
        self.G_EMA = G_EMA
        self.D_optim = D_optim
        self.G_optim = G_optim
        self.dl_train = dl_train
        self.dl_val = dl_val
        self.scaler_D = scaler_D
        self.scaler_G = scaler_G
        self.loss_handler = loss_handler
        self.max_images_to_train = max_images_to_train
        self.images_per_val = ims_per_val
        self.images_per_log = ims_per_log
        self.evaluate_fn = evaluate_fn
        self.batch_size = batch_size
        self.broadcast_buffers = broadcast_buffers
        self.fp16_ddp_accumulate = fp16_ddp_accumulate
        self.train_state = EasyDict(
            next_log_step=0,
            next_val_step=ims_per_val,
            total_time=0
        )

        checkpointer.register_models(dict(
            generator=G, discriminator=D, EMA_generator=G_EMA,
            D_optimizer=D_optim,
            G_optimizer=G_optim,
            train_state=self.train_state,
            scaler_D=self.scaler_D,
            scaler_G=self.scaler_G
        ))
        if checkpointer.has_checkpoint():
            checkpointer.load_registered_models()
            logger.log(f"Resuming training from: global step: {logger.global_step()}")
        else:
            logger.add_dict({
                "stats/discriminator_parameters": tops.num_parameters(self.D),
                "stats/generator_parameters": tops.num_parameters(self.G),
            }, commit=False)
        if save_state:
            # If the job is unexpectedly killed, there could be a mismatch between
            # previously saved checkpoint and the current checkpoint.
            atexit.register(checkpointer.save_registered_models)

        self._ims_per_log = ims_per_log
        self.to_log = AverageMeter()
        self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad]
        self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad]
        logger.add_dict({
            "stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D),
            "stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G),
        }, commit=False, level=logging.INFO)
        check_ddp_consistency(self.D)
        check_ddp_consistency(self.G)
        check_ddp_consistency(self.G_EMA.generator)

    def train_loop(self):
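        """Trains until `max_images_to_train` images are processed, logging averaged
        losses every `ims_per_log` images and evaluating every `ims_per_val` images.
        Training is aborted if either AMP gradient scale collapses below 1e-8.
        """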
        self.log_time()
        while logger.global_step() <= self.max_images_to_train:
            batch = next(self.dl_train)
            self.G_EMA.update_beta()
            self.to_log.update(self.step_D(batch))
            self.to_log.update(self.step_G(batch))
            self.G_EMA.update(self.G)
            if logger.global_step() >= self.train_state.next_log_step:
                to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()}
                to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()})
                to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()})
                self.to_log = AverageMeter()
                logger.add_dict(to_log, commit=True)
                self.train_state.next_log_step += self.images_per_log
            if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8:
                print("Stopping training as gradient scale < 1e-8")
                logger.log("Stopping training as gradient scale < 1e-8")
                break
            if logger.global_step() >= self.train_state.next_val_step:
                self.evaluate()
                self.log_time()
                self.save_images()
                self.train_state.next_val_step += self.images_per_val
            logger.step(self.batch_size * tops.world_size())
        logger.log(f"Reached end of training at step {logger.global_step()}.")
        checkpointer.save_registered_models()

    def estimate_ims_per_hour(self):
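        """Benchmarks training throughput on a single cached batch and logs
        images/hour, images/day and a 4-day extrapolation.
        """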
        batch = next(self.dl_train)
        n_ims = int(100e3)
        n_steps = int(n_ims / (self.batch_size * tops.world_size()))
        n_ims = n_steps * self.batch_size * tops.world_size()
        for i in range(10):  # Warmup
            self.G_EMA.update_beta()
            self.step_D(batch)
            self.step_G(batch)
            self.G_EMA.update(self.G)
        start_time = time.time()
        for i in utils.tqdm_(list(range(n_steps))):
            self.G_EMA.update_beta()
            self.step_D(batch)
            self.step_G(batch)
            self.G_EMA.update(self.G)
        total_time = time.time() - start_time
        ims_per_sec = n_ims / total_time
        ims_per_hour = ims_per_sec * 60 * 60
        ims_per_day = ims_per_hour * 24
        logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M")
        logger.log(f"Images per day: {ims_per_day/1e6:.3f}M")
        ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4))
        logger.log(f"Images per 4 days: {ims_per_4_day}")
        logger.add_dict({
            "stats/ims_per_day": ims_per_day,
            "stats/ims_per_4_day": ims_per_4_day
        })

    def log_time(self):
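        """Logs images/sec, accumulated training hours and the estimated hours
        remaining since the previous call.
        """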
        if not hasattr(self, "start_time"):
            self.start_time = time.time()
            self.last_time_step = logger.global_step()
            return
        n_images = logger.global_step() - self.last_time_step
        if n_images == 0:
            return
        n_secs = time.time() - self.start_time
        n_ims_per_sec = n_images / n_secs
        training_time_hours = n_secs / 60 / 60
        self.train_state.total_time += training_time_hours
        remaining_images = self.max_images_to_train - logger.global_step()
        remaining_time = remaining_images / n_ims_per_sec / 60 / 60
        logger.add_dict({
            "stats/n_ims_per_sec": n_ims_per_sec,
            "stats/total_training_time_hours": self.train_state.total_time,
            "stats/remaining_time_hours": remaining_time
        })
        self.last_time_step = logger.global_step()
        self.start_time = time.time()

    def save_images(self):
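        """Logs two image grids from the validation loader: truncated samples next to
        the corresponding reals, and `ims_diverse` samples per image generated with
        different latent codes.
        """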
        dl_val = iter(self.dl_val)
        batch = next(dl_val)

        # TRUNCATED visualization
        ims_to_log = 8
        self.G_EMA.eval()
        z = self.G.get_z(batch["img"])
        fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"]
        fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu()
        if "__key__" in batch:
            batch.pop("__key__")
        real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
        to_vis = torch.cat((real, fakes_truncated))
        logger.add_images("images/truncated", to_vis, nrow=2)

        # Diverse images
        ims_diverse = 3
        batch = next(dl_val)
        to_vis = []
        for i in range(ims_diverse):
            z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1)
            fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu()
            to_vis.append(fakes)
        if "__key__" in batch:
            batch.pop("__key__")
        reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
        to_vis.insert(0, reals)
        to_vis = torch.cat(to_vis)
        logger.add_images("images/diverse", to_vis, nrow=ims_diverse + 1)
        self.G_EMA.train()

    def evaluate(self):
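        """Saves a checkpoint, optionally re-checks DDP consistency, and logs the
        validation metrics computed by `evaluate_fn` with the EMA generator.
        """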
        logger.log("Starting evaluation.")
        self.G_EMA.eval()
        try:
            checkpointer.save_registered_models(max_keep=3)
        except Exception:
            logger.log("Could not save checkpoint.")
        if self.broadcast_buffers:
            check_ddp_consistency(self.G)
            check_ddp_consistency(self.D)
        metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val)
        metrics = {f"metrics/{k}": v for k, v in metrics.items()}
        logger.add_dict(metrics, level=logging.INFO)

    def step_D(self, batch):
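        """One discriminator update: scaled backward pass, cross-worker gradient
        (and optional buffer) accumulation, then a scaled optimizer step.
        """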
        utils.set_requires_grad(self.trainable_params_D, True)
        utils.set_requires_grad(self.trainable_params_G, False)
        tops.zero_grad(self.D)
        loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D)
        with torch.autograd.profiler.record_function("D_step"):
            self.scaler_D.scale(loss).backward()
            accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
            if self.broadcast_buffers:
                accumulate_buffers(self.D)
                accumulate_buffers(self.G)
            # Step will not unscale if unscale is called previously.
            self.scaler_D.step(self.D_optim)
            self.scaler_D.update()
        utils.set_requires_grad(self.trainable_params_D, False)
        utils.set_requires_grad(self.trainable_params_G, False)
        return to_log

    def step_G(self, batch):
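        """One generator update, mirroring `step_D`."""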
        utils.set_requires_grad(self.trainable_params_D, False)
        utils.set_requires_grad(self.trainable_params_G, True)
        tops.zero_grad(self.G)
        loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G)
        with torch.autograd.profiler.record_function("G_step"):
            self.scaler_G.scale(loss).backward()
            accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
            if self.broadcast_buffers:
                accumulate_buffers(self.G)
                accumulate_buffers(self.D)
            self.scaler_G.step(self.G_optim)
            self.scaler_G.update()
        utils.set_requires_grad(self.trainable_params_D, False)
        utils.set_requires_grad(self.trainable_params_G, False)
        return to_log