from tqdm import tqdm
from roma.utils.utils import to_cuda
import roma
import torch
import wandb


def log_param_statistics(named_parameters, norm_type=2):
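    """Log the global parameter norm, the global gradient norm, and the names
    of any parameters whose gradients contain NaN/Inf values to wandb."""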
    named_parameters = list(named_parameters)
    # Only consider parameters that actually received a gradient this step.
    grads = [p.grad for n, p in named_parameters if p.grad is not None]
    weight_norms = [
        p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None
    ]
    names = [n for n, p in named_parameters if p.grad is not None]
    param_norm = torch.stack(weight_norms).norm(p=norm_type)
    device = grads[0].device
    grad_norms = torch.stack(
        [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    )
    # NaN/Inf gradients typically indicate AMP overflow or a bad batch.
    nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
    nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
    total_grad_norm = torch.norm(grad_norms, norm_type)
    if torch.any(nans_or_infs):
        print(f"These params have NaN or Inf grads: {nan_inf_names}")
    wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
    wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)


def train_step(
    train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs
):
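    """Run a single AMP training step: forward, scaled backward, gradient
    logging and clipping, then optimizer step and scaler update.

    Extra keyword arguments (e.g. lr_scheduler, n) are accepted via **kwargs
    and ignored, so callers can pass them uniformly."""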
    optimizer.zero_grad()
    out = model(train_batch)
    loss = objective(out, train_batch)
    grad_scaler.scale(loss).backward()
    # Unscale before logging/clipping so gradient norms are measured at their
    # true (unscaled) magnitude.
    grad_scaler.unscale_(optimizer)
    log_param_statistics(model.named_parameters())
    torch.nn.utils.clip_grad_norm_(
        model.parameters(), grad_clip_norm
    )  # what should max norm be?
    grad_scaler.step(optimizer)
    grad_scaler.update()
    # _scale is a private GradScaler attribute; it is read to log the current
    # loss scale, which is then floored at 1.0 so that repeated overflows
    # cannot push the scale below 1.
    wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP)
    if grad_scaler._scale < 1.0:
        grad_scaler._scale = torch.tensor(1.0).to(grad_scaler._scale)
    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE  # increment global step
    return {"train_out": out, "train_loss": loss.item()}


def train_k_steps(
    n_0,
    k,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler,
    progress_bar=True,
    grad_clip_norm=1.0,
    warmup=None,
    ema_model=None,
):
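    """Run k training steps starting from global step n_0.

    `dataloader` must be an iterator (it is advanced with next()). `warmup`,
    if given, should expose a dampening() context manager (as in the
    pytorch_warmup package) that damps lr_scheduler steps during warmup.
    `ema_model`, if given, is updated after every step."""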
    for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
        batch = next(dataloader)
        model.train(True)
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            n=n,
            grad_clip_norm=grad_clip_norm,
        )
        if ema_model is not None:
            ema_model.update()
        if warmup is not None:
            with warmup.dampening():
                lr_scheduler.step()
        else:
            lr_scheduler.step()
        # Log the learning rate of each parameter group.
        for grp, lr in enumerate(lr_scheduler.get_last_lr()):
            wandb.log({f"lr_group_{grp}": lr})


def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    grad_scaler=None,
    lr_scheduler=None,
    epoch=None,
):
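    """Train for one full pass over `dataloader` and return the training state.

    `grad_scaler` is forwarded to train_step, which requires it for AMP loss
    scaling."""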
    model.train(True)
    print(f"At epoch {epoch}")
    for batch in tqdm(dataloader, mininterval=5.0):
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            grad_scaler=grad_scaler,
        )
        lr_scheduler.step()
    return {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "epoch": epoch,
    }


def train_k_epochs(
    start_epoch,
    end_epoch,
    dataloader,
    model,
    objective,
    optimizer,
    lr_scheduler,
    grad_scaler=None,
):
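    """Train from start_epoch through end_epoch (inclusive), running one
    train_epoch per epoch."""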
    for epoch in range(start_epoch, end_epoch + 1):
        train_epoch(
            dataloader=dataloader,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            epoch=epoch,
        )
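
# Minimal usage sketch (illustrative only, not part of the original module).
# MyModel, MyLoss, and train_dataloader are hypothetical stand-ins, and a
# wandb run plus roma.GLOBAL_STEP / roma.STEP_SIZE / roma.RANK are assumed to
# be initialized elsewhere.
#
#   import itertools
#   import torch
#
#   model = MyModel().cuda()
#   objective = MyLoss()
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10_000)
#   grad_scaler = torch.cuda.amp.GradScaler()
#   train_iter = iter(itertools.cycle(train_dataloader))  # train_k_steps uses next()
#
#   train_k_steps(
#       n_0=0, k=10_000, dataloader=train_iter, model=model, objective=objective,
#       optimizer=optimizer, lr_scheduler=lr_scheduler, grad_scaler=grad_scaler,
#   )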