import torch
from torch.utils.data import random_split, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger

from src.trainer import LitTrainer


def argmax(a):
    # Pure-Python argmax: index of the largest element in a sequence.
    return max(range(len(a)), key=lambda x: a[x])


def get_dataloaders(dataset, test_data):
    # 80/20 train/validation split.
    train_size = round(len(dataset) * 0.8)
    validate_size = len(dataset) - train_size
    train_data, validate_data = random_split(dataset, [train_size, validate_size])

    # num_workers=8 assumes 8 CPU cores are available; note the default
    # batch_size of 1, which the training loop below relies on.
    return (
        DataLoader(train_data, num_workers=8),
        DataLoader(validate_data, num_workers=8),
        DataLoader(test_data, num_workers=8),
    )


def train_loop(net, batch, loss_fn, optim, device="cuda"):
    x, y = batch
    x = x.to(device)
    y = y.to(device)

    # reshape(1, -1) assumes a batch size of 1 (one sample per batch).
    y_pred = net(x).reshape(1, -1)
    loss = loss_fn(y_pred, y)
    # 1 if the highest-scoring class matches the label, else 0.
    truth_count = int(argmax(y_pred.flatten().tolist()) == int(y))

    optim.zero_grad()
    loss.backward()
    optim.step()

    return loss.item(), truth_count


def train_net_manually(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10, device="cuda"):
    # Note: validate_loader is currently unused in the manual loop.
    for i in range(epochs):
        print(f"Epoch: {i}")

        epoch_loss = 0
        epoch_truth_count = 0
        for idx, batch in enumerate(train_loader):
            loss, truth_count = train_loop(net, batch, loss_fn, optim, device)

            epoch_loss += loss
            epoch_truth_count += truth_count

            if idx % 1000 == 0:
                print(f"Loss: {loss} ({idx} / {len(train_loader)} x {i})")

        print(f"Epoch Loss: {epoch_loss}")
        # With batch size 1, len(train_loader) equals the number of samples.
        print(f"Epoch Accuracy: {epoch_truth_count / len(train_loader)}")
    # Assumes the checkpoints/pytorch/ directory already exists.
    torch.save(net.state_dict(), "checkpoints/pytorch/version_1.pt")


def train_net_lightning(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10):
    # Log metrics to a local MLflow file store.
    logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")

    # Wrap the raw network; LitTrainer picks up the optimizer and loss
    # from the attributes set below.
    pl_net = LitTrainer(net)
    pl_net.optim = optim
    pl_net.loss = loss_fn
    # limit_train_batches=100 caps each epoch at 100 training batches.
    trainer = pl.Trainer(limit_train_batches=100, max_epochs=epochs,
                         default_root_dir="../checkpoints", logger=logger)
    trainer.fit(pl_net, train_loader, validate_loader)
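

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how these helpers might be wired together, using a
# synthetic TensorDataset and a stand-in nn.Linear model; both are
# placeholders for your real dataset and network.
if __name__ == "__main__":
    import os
    import torch.nn as nn
    from torch.utils.data import TensorDataset

    # train_net_manually saves here, so make sure the directory exists.
    os.makedirs("checkpoints/pytorch", exist_ok=True)

    # Synthetic 10-class toy data so the sketch runs end to end.
    dataset = TensorDataset(torch.randn(100, 16), torch.randint(0, 10, (100,)))
    test_data = TensorDataset(torch.randn(20, 16), torch.randint(0, 10, (20,)))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = nn.Linear(16, 10).to(device)  # stand-in model; substitute your own

    train_loader, validate_loader, test_loader = get_dataloaders(dataset, test_data)
    optim = torch.optim.SGD(net.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    # Either drive the loop by hand...
    train_net_manually(net, optim, loss_fn, train_loader, epochs=1, device=device)
    # ...or delegate to Lightning:
    # train_net_lightning(net, optim, loss_fn, train_loader, validate_loader, epochs=1)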