from tqdm import tqdm
from roma.utils.utils import to_cuda
import roma
import torch
import wandb


def log_param_statistics(named_parameters, norm_type=2):
    """Log total gradient and parameter norms to wandb; warn about NaN/Inf gradients."""
    named_parameters = list(named_parameters)
    grads = [p.grad for n, p in named_parameters if p.grad is not None]
    weight_norms = [
        p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None
    ]
    names = [n for n, p in named_parameters if p.grad is not None]
    param_norm = torch.stack(weight_norms).norm(p=norm_type)
    device = grads[0].device
    grad_norms = torch.stack(
        [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    )
    nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
    nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
    total_grad_norm = torch.norm(grad_norms, norm_type)
    if torch.any(nans_or_infs):
        print(f"These params have nan or inf grads: {nan_inf_names}")
    wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
    wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)


def train_step(
    train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs
):
    """Run a single optimization step with AMP loss scaling and gradient clipping."""
    optimizer.zero_grad()
    out = model(train_batch)
    loss = objective(out, train_batch)
    grad_scaler.scale(loss).backward()
    # Unscale before logging/clipping so that norms are measured in true (unscaled) units.
    grad_scaler.unscale_(optimizer)
    log_param_statistics(model.named_parameters())
    torch.nn.utils.clip_grad_norm_(
        model.parameters(), grad_clip_norm
    )  # what should max norm be?
    grad_scaler.step(optimizer)
    grad_scaler.update()
    wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP)
    # Keep the loss scale from decaying below 1 (touches GradScaler internals).
    if grad_scaler._scale < 1.0:
        grad_scaler._scale = torch.tensor(1.0).to(grad_scaler._scale)
    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE  # increment global step
    return {"train_out": out, "train_loss": loss.item()}


def train_k_steps(
    n_0,
    k,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler,
    progress_bar=True,
    grad_clip_norm=1.0,
    warmup=None,
    ema_model=None,
):
    """Train for k steps starting at global step n_0, stepping the LR scheduler once per step."""
    for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
        batch = next(dataloader)
        model.train(True)
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            n=n,
            grad_clip_norm=grad_clip_norm,
        )
        if ema_model is not None:
            ema_model.update()
        if warmup is not None:
            with warmup.dampening():
                lr_scheduler.step()
        else:
            lr_scheduler.step()
        for grp, lr in enumerate(lr_scheduler.get_last_lr()):
            wandb.log({f"lr_group_{grp}": lr})


def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    grad_scaler=None,
    epoch=None,
):
    """Train for one full pass over the dataloader (epoch-based alternative to train_k_steps)."""
    model.train(True)
    print(f"At epoch {epoch}")
    for batch in tqdm(dataloader, mininterval=5.0):
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            grad_scaler=grad_scaler,
        )
        lr_scheduler.step()
    return {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "epoch": epoch,
    }


def train_k_epochs(
    start_epoch,
    end_epoch,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler=None,
):
    """Run train_epoch for every epoch in [start_epoch, end_epoch]."""
    for epoch in range(start_epoch, end_epoch + 1):
        train_epoch(
            dataloader=dataloader,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            epoch=epoch,
        )
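

if __name__ == "__main__":
    # Hedged usage sketch (not part of the original module): it only shows how
    # train_k_steps wires the model, objective, optimizer, LR scheduler, and
    # GradScaler together. ToyModel, toy_objective, the cycled toy dataloader,
    # and the AdamW/CosineAnnealingLR choices are illustrative assumptions, not
    # RoMa's actual training configuration. Assumes a CUDA device is available;
    # wandb runs in disabled mode so nothing is logged or uploaded.
    import itertools

    wandb.init(mode="disabled")
    # Make sure the module-level globals used by train_step/train_k_steps exist.
    for attr, default in (("GLOBAL_STEP", 0), ("STEP_SIZE", 1), ("RANK", 0)):
        if not hasattr(roma, attr):
            setattr(roma, attr, default)

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 1)

        def forward(self, batch):
            # Batches are dicts, matching how train_step passes them to the model.
            return self.linear(batch["x"])

    def toy_objective(out, batch):
        return torch.nn.functional.mse_loss(out, batch["y"])

    model = ToyModel().cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    grad_scaler = torch.cuda.amp.GradScaler()
    # Infinite iterator of batches, mimicking the `next(dataloader)` usage above.
    dataloader = itertools.cycle([{"x": torch.randn(4, 8), "y": torch.randn(4, 1)}])

    train_k_steps(
        n_0=0,
        k=100,
        dataloader=dataloader,
        model=model,
        objective=toy_objective,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        grad_scaler=grad_scaler,
    )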