File size: 1,856 Bytes
abe84f3
 
 
 
 
 
 
 
 
a0f925f
 
 
abe84f3
 
a0f925f
abe84f3
 
 
 
a0f925f
 
 
 
abe84f3
 
 
 
 
 
 
 
 
 
a0f925f
 
 
 
abe84f3
a0f925f
abe84f3
 
 
 
 
a0f925f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abe84f3
a0f925f
abe84f3
 
 
a0f925f
abe84f3
a0f925f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python
# coding: utf-8
import torch
from torch import nn
import numpy as np

import plotly.express as px
from plotly.subplots import make_subplots

from src.trainer import LitTrainer
from src.models import CNN
from src.dataset import DatasetMNIST, download_mnist


def load_pl_net(path="checkpoints/lightning_logs/version_26/checkpoints/epoch=9-step=1000.ckpt"):
    pl_net = LitTrainer.load_from_checkpoint(path, model=CNN(1, 10))
    return pl_net


def load_torch_net(path="checkpoints/pytorch/version_0.pt"):
    state_dict = torch.load(path)
    net = CNN(1, 10)
    net.load_state_dict(state_dict)
    return net


def get_sequence(model):
    fig = make_subplots(rows=2, cols=5)

    i, j = 0, np.random.randint(0, 30000)

    while i < 10:
        x, y = dataset[j]

        predicted, p = predict(x, model)

        if predicted == i and p > 0.95:
            img = np.flip(np.array(x.reshape(28, 28)), 0)
            fig.add_trace(px.imshow(img).data[0], row=int(i/5)+1, col=i % 5+1)
            i += 1
        j += 1
    return fig


def predict(x, model, device="cuda"):
    y_pred = model(x.to(device)).detach().cpu()
    predicted = int(np.argmax(y_pred))
    p = torch.max(nn.functional.softmax(y_pred, dim=0))

    return predicted, p


def predict_interval(x, model, device="cuda"):
    y_pred = model(x.to(device))

    print(y_pred)

    predicted = np.argsort(y_pred.cpu().detach().numpy())
    p = nn.functional.softmax(y_pred, dim=0)

    return {int(i): float(p[i]) for i in predicted}


if __name__ == "__main__":
    mnist = download_mnist("downloads/mnist/")
    dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])

    print("PyTorch Lightning Network")
    get_sequence(load_pl_net().to("cuda")).show()
    print("Manual Network")
    get_sequence(load_torch_net().to("cuda")).show()