# lenet_mnist/src/train.py
import os

import torch
from torch.utils.data import random_split, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger

from src.trainer import LitTrainer


def argmax(a):
    # Index of the largest element in a sequence (pure-Python argmax).
    return max(range(len(a)), key=lambda x: a[x])


def get_dataloaders(dataset, test_data):
    # 80/20 train/validation split of the training set.
    train_size = round(len(dataset) * 0.8)
    validate_size = len(dataset) - train_size
    train_data, validate_data = random_split(dataset, [train_size, validate_size])
    # num_workers=8 assumes 8 CPU cores are available.
    return (DataLoader(train_data, num_workers=8),
            DataLoader(validate_data, num_workers=8),
            DataLoader(test_data, num_workers=8))
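
# Example (sketch): wiring torchvision's MNIST into get_dataloaders. The
# "data" download root is an assumption for illustration, not part of this repo.
#
#   from torchvision import datasets, transforms
#   dataset = datasets.MNIST("data", train=True, download=True,
#                            transform=transforms.ToTensor())
#   test_data = datasets.MNIST("data", train=False, download=True,
#                              transform=transforms.ToTensor())
#   train_loader, validate_loader, test_loader = get_dataloaders(dataset, test_data)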


def train_loop(net, batch, loss_fn, optim, device="cuda"):
    x, y = batch
    x = x.to(device)
    y = y.to(device)
    # The loaders above use the default batch_size=1, so reshape to (1, n_classes).
    y_pred = net(x).reshape(1, -1)
    loss = loss_fn(y_pred, y)
    # 1 if the predicted class matches the label, else 0.
    truth_count = int(argmax(y_pred.flatten()) == y.item())
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item(), truth_count


def train_net_manually(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10, device="cuda"):
    net = net.to(device)
    for i in range(epochs):
        print(f"Epoch: {i}")
        epoch_loss = 0
        epoch_truth_count = 0
        for idx, batch in enumerate(train_loader):
            loss, truth_count = train_loop(net, batch, loss_fn, optim, device)
            epoch_loss += loss
            epoch_truth_count += truth_count
            if idx % 1000 == 0:
                print(f"Loss: {loss} ({idx} / {len(train_loader)} x {i})")
        print(f"Epoch Loss: {epoch_loss}")
        # With batch_size=1, len(train_loader) equals the number of samples.
        print(f"Epoch Accuracy: {epoch_truth_count / len(train_loader)}")
    os.makedirs("checkpoints/pytorch", exist_ok=True)
    torch.save(net.state_dict(), "checkpoints/pytorch/version_1.pt")


def train_net_lightning(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10):
    logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
    pl_net = LitTrainer(net)
    pl_net.optim = optim
    pl_net.loss = loss_fn
    # limit_train_batches=100 caps each epoch at 100 batches.
    trainer = pl.Trainer(limit_train_batches=100, max_epochs=epochs,
                         default_root_dir="../checkpoints", logger=logger)
    trainer.fit(pl_net, train_loader, validate_loader)
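

# Minimal end-to-end sketch. The inline nn.Sequential below is a stand-in
# LeNet-style network and the "data" root is an assumption for illustration;
# swap in the repo's own model class for real runs.
if __name__ == "__main__":
    from torch import nn, optim as torch_optim
    from torchvision import datasets, transforms

    transform = transforms.ToTensor()
    dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    test_data = datasets.MNIST("data", train=False, download=True, transform=transform)
    train_loader, validate_loader, _ = get_dataloaders(dataset, test_data)

    # Stand-in LeNet-style network for 1x28x28 MNIST inputs.
    net = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10),
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    optimizer = torch_optim.SGD(net.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    train_net_manually(net, optimizer, loss_fn, train_loader, validate_loader,
                       epochs=1, device=device)
    # Alternatively, via Lightning:
    # train_net_lightning(net, optimizer, loss_fn, train_loader, validate_loader, epochs=1)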