Spaces:
Sleeping
Sleeping
from tqdm import tqdm | |
from dkm.utils.utils import to_cuda | |
def train_step(train_batch, model, objective, optimizer, **kwargs):
    """Perform a single optimisation step on one batch.

    The optimizer's gradients are cleared first, then ``model`` is run on
    ``train_batch`` and the prediction is scored with
    ``objective(prediction, train_batch)``. The loss is back-propagated and
    ``optimizer.step()`` is applied exactly once.

    Args:
        train_batch: Batch handed unchanged to both ``model`` and ``objective``.
        model: Callable mapping the batch to a prediction.
        objective: Callable returning a loss object that supports
            ``.backward()`` and ``.item()``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        **kwargs: Accepted and ignored, so callers may forward extra context
            (for example ``lr_scheduler`` or a step index ``n``).

    Returns:
        dict with ``"train_out"`` (the raw model output) and ``"train_loss"``
        (the loss converted with ``.item()``).
    """
    optimizer.zero_grad()
    prediction = model(train_batch)
    loss = objective(prediction, train_batch)
    loss.backward()
    optimizer.step()
    # .item() is read after the update, exactly as before; it only extracts
    # the already-computed loss value.
    loss_value = loss.item()
    return {"train_out": prediction, "train_loss": loss_value}
def train_k_steps(
    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, progress_bar=True
):
    """Train for ``k`` optimisation steps with indices ``n_0`` .. ``n_0 + k - 1``.

    Each step pulls one batch from ``dataloader``, puts ``model`` in training
    mode, moves the batch with ``to_cuda`` and performs one ``train_step``.
    After each step it advances ``lr_scheduler`` once, so the schedule
    advances per batch rather than per epoch.

    Args:
        n_0: Index of the first step; each step index is forwarded to
            ``train_step`` as ``n``.
        k: Number of steps to run.
        dataloader: Source of batches. If it is an iterator, it is consumed
            incrementally, so repeated calls continue where the previous one
            stopped. If it is any other iterable (e.g. a ``DataLoader``), a
            fresh iterator is created for this call.
        model: Forwarded to ``train_step``; ``model.train(True)`` is called
            before every step.
        objective: Forwarded to ``train_step``.
        optimizer: Forwarded to ``train_step``.
        lr_scheduler: Object whose ``step()`` is called after every batch;
            it is also forwarded to ``train_step``.
        progress_bar: Show a ``tqdm`` progress bar when True.

    Raises:
        StopIteration: If ``dataloader`` is exhausted before ``k`` batches
            have been drawn.
    """
    # next() requires an iterator. iter() returns an iterator unchanged, so
    # callers sharing one long-lived iterator keep their position. Plain
    # iterables, which previously raised TypeError, now work too.
    batches = iter(dataloader)
    for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar):
        batch = next(batches)
        model.train(True)
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            n=n,
        )
        lr_scheduler.step()
def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    epoch=None,
):
    """Make one full pass over ``dataloader``, taking one optimiser step per batch.

    ``model`` is switched to training mode once, before the pass begins.
    Each batch is moved with ``to_cuda`` and handed to ``train_step``.
    ``lr_scheduler.step()`` is then called, so the schedule advances per
    batch, not per epoch. Progress is shown with ``tqdm``, refreshed at most
    every 5 seconds.

    Although every argument defaults to None, ``model`` must be supplied:
    ``model.train(True)`` is called unconditionally.

    Returns:
        dict holding the same ``model``, ``optimizer`` and ``lr_scheduler``
        objects that were passed in, plus the ``epoch`` label.
    """
    model.train(True)
    print(f"At epoch {epoch}")
    progress = tqdm(dataloader, mininterval=5.0)
    for raw_batch in progress:
        train_step(
            train_batch=to_cuda(raw_batch),
            model=model,
            objective=objective,
            optimizer=optimizer,
        )
        lr_scheduler.step()
    state = dict(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epoch=epoch,
    )
    return state
def train_k_epochs(
    start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
):
    """Run ``train_epoch`` for each epoch from ``start_epoch`` through ``end_epoch``.

    ``end_epoch`` is inclusive. The dictionaries returned by ``train_epoch``
    are discarded, and this function returns None.
    """
    # Every epoch receives the same training objects; only the label changes.
    shared = dict(
        dataloader=dataloader,
        model=model,
        objective=objective,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )
    last_inclusive = end_epoch + 1
    for current_epoch in range(start_epoch, last_inclusive):
        train_epoch(epoch=current_epoch, **shared)