Spaces:
Sleeping
Sleeping
File size: 2,156 Bytes
987f571 a0f925f 49b098d a0f925f 49b098d a0f925f 49b098d a0f925f 49b098d a0f925f 49b098d a0f925f 49b098d a0f925f 987f571 a0f925f 987f571 a0f925f 49b098d 987f571 a0f925f 49b098d a0f925f 987f571 49b098d a0f925f 49b098d a0f925f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
from torch.utils.data import random_split, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from src.trainer import LitTrainer
def argmax(a):
    """Return the index of the largest element of *a* (first index on ties)."""
    best_idx = 0
    for idx in range(1, len(a)):
        if a[idx] > a[best_idx]:
            best_idx = idx
    return best_idx
def get_dataloaders(dataset, test_data, train_frac=0.8, num_workers=8):
    """Split *dataset* into train/validation parts and wrap all splits in DataLoaders.

    Args:
        dataset: the full training dataset to split.
        test_data: held-out test dataset (not split).
        train_frac: fraction of *dataset* used for training (default 0.8, as before).
        num_workers: worker processes per DataLoader (default 8, matching the
            original hard-coded value for an 8-core machine).

    Returns:
        (train_loader, validate_loader, test_loader) tuple of DataLoaders.
    """
    train_size = round(len(dataset) * train_frac)
    validate_size = len(dataset) - train_size
    train_data, validate_data = random_split(dataset, [train_size, validate_size])
    return (
        DataLoader(train_data, num_workers=num_workers),
        DataLoader(validate_data, num_workers=num_workers),
        DataLoader(test_data, num_workers=num_workers),
    )
def train_loop(net, batch, loss_fn, optim, device="cuda"):
    """Run a single optimization step on one batch.

    Args:
        net: model to train (already on *device*).
        batch: (x, y) pair; y is the integer class label tensor.
        loss_fn: criterion taking (logits, target).
        optim: optimizer over net's parameters.
        device: device to move the batch to (default "cuda").

    Returns:
        (loss_value, correct): the scalar loss and 1 if the prediction
        matched the label, else 0.
    """
    x, y = batch
    x = x.to(device)
    y = y.to(device)
    # NOTE(review): reshape(1, -1) assumes batch size 1 — confirm the loader
    # feeds single samples.
    y_pred = net(x).reshape(1, -1)
    loss = loss_fn(y_pred, y)
    # BUG FIX: the original `argmax(...) == y` compared a Python int against a
    # tensor, producing a bool *tensor* that polluted the epoch accumulator.
    # torch.argmax(...).item() yields a plain int so the flag stays numeric.
    correct = int(torch.argmax(y_pred).item() == y.item())
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item(), correct
def train_net_manually(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10,
                       device="cuda", checkpoint_path="checkpoints/pytorch/version_1.pt"):
    """Train *net* with a hand-written loop, printing progress each epoch.

    Args:
        net: model to train.
        optim: optimizer over net's parameters.
        loss_fn: criterion passed through to train_loop.
        train_loader: DataLoader of training batches.
        validate_loader: accepted for interface symmetry with
            train_net_lightning; not used by this loop.
        epochs: number of passes over train_loader (default 10).
        device: device batches are moved to (default "cuda").
        checkpoint_path: where the final state_dict is saved (default matches
            the original hard-coded path).
    """
    for epoch in range(epochs):
        # BUG FIX: the original printed the literal "Epoch: 0" on every epoch.
        print(f"Epoch: {epoch}")
        epoch_loss = 0
        epoch_truth_count = 0
        for idx, batch in enumerate(train_loader):
            loss, truth_count = train_loop(net, batch, loss_fn, optim, device)
            epoch_loss += loss
            epoch_truth_count += truth_count
            if idx % 1000 == 0:
                print(f"Loss: {loss} ({idx} / {len(train_loader)} x {epoch}")
        print(f"Epoch Loss: {epoch_loss}")
        # Accuracy divides by batch count — valid because train_loop assumes
        # batch size 1 (one sample per batch).
        print(f"Epoch Accuracy: {epoch_truth_count / len(train_loader)}")
    # NOTE(review): source indentation was ambiguous; saving once after the
    # final epoch — confirm the original didn't checkpoint every epoch.
    torch.save(net.state_dict(), checkpoint_path)
def train_net_lightning(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10):
    """Train *net* via PyTorch Lightning, logging runs to a local MLflow store.

    Wraps *net* in the project's LitTrainer module, attaches the given
    optimizer and loss, and runs a capped-batch fit for *epochs* epochs.
    """
    wrapped = LitTrainer(net)
    wrapped.optim = optim
    wrapped.loss = loss_fn
    mlflow_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
    trainer = pl.Trainer(
        limit_train_batches=100,
        max_epochs=epochs,
        default_root_dir="../checkpoints",
        logger=mlflow_logger,
    )
    trainer.fit(wrapped, train_loader, validate_loader)
|