from tqdm import tqdm

from dkm.utils.utils import to_cuda


def train_step(train_batch, model, objective, optimizer, **kwargs):
    """Run one optimization step and return the model output and loss.

    Extra keyword arguments (e.g. lr_scheduler and n, as passed by
    train_k_steps) are absorbed by **kwargs and ignored here.
    """
    optimizer.zero_grad()
    out = model(train_batch)
    loss = objective(out, train_batch)
    loss.backward()
    optimizer.step()
    return {"train_out": out, "train_loss": loss.item()}


def train_k_steps(
    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, progress_bar=True
):
    """Train for k steps, from step n_0 to n_0 + k.

    Expects `dataloader` to be an iterator (next() is called on it directly).
    The lr_scheduler is stepped once per optimization step.
    """
    for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar):
        batch = next(dataloader)
        # (Re)ensure training mode each step, in case evaluation code
        # toggled it between steps.
        model.train(True)
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            n=n,
        )
        lr_scheduler.step()


def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    epoch=None,
):
    """Train for one full pass over the dataloader.

    Unlike train_k_steps, the lr_scheduler is stepped once per epoch,
    after the loop completes.
    """
    model.train(True)
    print(f"At epoch {epoch}")
    for batch in tqdm(dataloader, mininterval=5.0):
        batch = to_cuda(batch)
        train_step(
            train_batch=batch, model=model, objective=objective, optimizer=optimizer
        )
    lr_scheduler.step()
    return {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "epoch": epoch,
    }


def train_k_epochs(
    start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
):
    """Train from start_epoch to end_epoch, inclusive."""
    for epoch in range(start_epoch, end_epoch + 1):
        train_epoch(
            dataloader=dataloader,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            epoch=epoch,
        )
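

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original module):
# one way the helpers above could be wired together. ToyModel and
# toy_objective are hypothetical stand-ins for the real DKM model and loss;
# the actual training entry points live elsewhere in the repo. Assumes a CUDA
# device is available, since to_cuda moves batches to the GPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import itertools

    import torch
    from torch.utils.data import DataLoader

    class ToyModel(torch.nn.Module):
        """Hypothetical model that, like DKM's, consumes a whole batch dict."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 1)

        def forward(self, batch):
            return self.linear(batch["x"])

    def toy_objective(out, batch):
        """Hypothetical loss with the (output, batch) signature train_step expects."""
        return torch.nn.functional.mse_loss(out, batch["y"])

    # Toy data as dict samples; DataLoader's default collate turns a list of
    # dicts of tensors into a single dict of batched tensors.
    data = [{"x": torch.randn(8), "y": torch.randn(1)} for _ in range(64)]

    model = ToyModel().cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

    # train_k_steps calls next(dataloader) itself, so pass an iterator;
    # itertools.cycle keeps it from being exhausted mid-run.
    loader = itertools.cycle(DataLoader(data, batch_size=4))
    train_k_steps(
        0, 200, loader, model,
        objective=toy_objective,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )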