|
import torch |
|
from torch import nn |
|
|
|
from src.utils import get_batch |
|
|
|
|
|
@torch.no_grad()
def estimate_loss(model: nn.Module, eval_iters, block_size, batch_size, device):
    """Estimate the mean loss of ``model`` on the "train" and "val" splits.

    For each split, ``eval_iters`` batches are drawn with ``get_batch``, moved
    to ``device`` and passed through the model. The model is switched to eval
    mode for the duration of the call and gradient tracking is disabled by
    the ``torch.no_grad`` decorator.

    Args:
        model: Module invoked as ``model(X, Y)``; it must return a
            ``(logits, loss)`` pair where ``loss`` supports ``.item()``.
        eval_iters: Number of batches averaged per split.
        block_size: Forwarded unchanged to ``get_batch``.
        batch_size: Forwarded unchanged to ``get_batch``.
        device: Device the sampled batches are moved to before the forward
            pass.

    Returns:
        A dict with keys ``"train"`` and ``"val"``, each mapping to a 0-dim
        tensor holding the mean of the sampled losses (NaN when
        ``eval_iters`` is 0, since the mean of an empty tensor is NaN).
    """
    # Remember the caller's mode so evaluation does not silently flip a model
    # that was deliberately left in eval mode back into training mode.
    was_training = model.training
    out = {}
    model.eval()
    try:
        for split in ("train", "val"):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, block_size, batch_size)
                X, Y = X.to(device), Y.to(device)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the original mode even if a batch or forward pass raised.
        model.train(was_training)
    return out
|
|
|
|
|
def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
):
    """Run the optimisation loop with periodic evaluation and checkpointing.

    Every ``eval_interval`` steps (including step 0) the train/val losses are
    estimated with ``estimate_loss`` and printed. The first evaluation only
    records a validation baseline; afterwards the whole model is saved to
    ``checkpoints/model.pth`` each time the validation loss drops strictly
    below the best value seen so far.

    Args:
        model: Module invoked as ``model(xb, yb)``; it must return a
            ``(logits, loss)`` pair where ``loss`` supports ``.backward()``.
        optimizer: Optimizer providing ``zero_grad(set_to_none=True)`` and
            ``step()``.
        max_iters: Total number of optimisation steps to run.
        eval_interval: Evaluate whenever ``step % eval_interval == 0``.
        eval_iters: Forwarded to ``estimate_loss``.
        block_size: Forwarded to ``get_batch`` and ``estimate_loss``.
        batch_size: Forwarded to ``get_batch`` and ``estimate_loss``.
        device: Device the training batches are moved to.

    Returns:
        None.
    """
    # Function-scope import keeps this block self-contained.
    import os

    checkpoint_path = "checkpoints/model.pth"
    best_val_loss = None

    for step in range(max_iters):
        if step % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            print(
                f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )

            current_val_loss = float(losses["val"])
            if best_val_loss is None:
                # First evaluation: establish the baseline, nothing to compare yet.
                best_val_loss = current_val_loss
            elif current_val_loss < best_val_loss:
                # Update the running best so a later, worse model can never
                # overwrite a better checkpoint.
                best_val_loss = current_val_loss
                # torch.save does not create missing parent directories.
                os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
                torch.save(model, checkpoint_path)

        xb, yb = get_batch("train", block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)

        _, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
|
|