File size: 3,093 Bytes
159b4ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Train a simple LeNet-style convolutional network on the MNIST dataset and
export the trained network to ONNX.
"""

import random

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F

# Download the MNIST Dataset

def get_mnist_dataset():
    """Return the MNIST ``(train_set, test_set)`` pair, downloading it if needed.

    Both splits live under ``./data`` and share a single
    ``transforms.ToTensor()`` transform. The training split is built first,
    then the test split.
    """
    to_tensor = transforms.ToTensor()
    return tuple(
        datasets.MNIST(root='./data', train=is_train, transform=to_tensor, download=True)
        for is_train in (True, False)
    )

# Create the lenet model

class Classifier(torch.nn.Module):
    """LeNet-style convolutional classifier producing 10 scores per image.

    The spatial-size comments assume single-channel 28x28 inputs. That is the
    size ``nn.Linear(32 * 4 * 4, 100)`` expects after two 5x5 convolutions
    and two 2x2 poolings.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two (conv -> ReLU -> 2x2 max-pool) stages.
        features = [
            nn.Conv2d(1, 32, 5),    # 28 -> 24
            nn.ReLU(),
            nn.MaxPool2d(2, 2),     # 24 -> 12
            nn.Conv2d(32, 32, 5),   # 12 ->  8
            nn.ReLU(),
            nn.MaxPool2d(2, 2),     #  8 ->  4
        ]
        # Head: flatten the 32x4x4 feature map, then three linear layers.
        head = [
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        ]
        # A single Sequential keeps the layer indices (and state_dict keys)
        # identical to one flat layer list.
        self.network = nn.Sequential(*features, *head)

    def forward(self, x):
        """Return the raw (no softmax) outputs of ``self.network`` for ``x``."""
        scores = self.network(x)
        return scores

# Compute accuracy function

def compute_accuracy(model, data_set, nb_samples):
    """Estimate the accuracy of ``model`` on random samples of ``data_set``.

    ``nb_samples`` indices are drawn uniformly *with replacement* using
    ``torch.randint``. The result is therefore a sampled estimate, not the
    exact accuracy over the whole set.

    Args:
        model: ``torch.nn.Module`` called on a ``(1, 1, 28, 28)`` tensor. The
            predicted label is the argmax of its output.
        data_set: indexable sequence of ``(img, label)`` pairs. Each ``img``
            must hold 28*28 elements.
        nb_samples: number of samples to draw; must be positive.

    Returns:
        Fraction of sampled items whose predicted label equals ``label``.

    Raises:
        ValueError: if ``nb_samples`` is not positive or ``data_set`` is empty.
    """
    if nb_samples <= 0:
        raise ValueError(f"nb_samples must be positive, got {nb_samples}")
    if len(data_set) == 0:
        raise ValueError("data_set is empty")

    # Evaluation only: switch to eval mode and disable autograd so no graph is
    # built per sample. The caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    nb_valid = 0
    try:
        with torch.no_grad():
            for _ in range(nb_samples):
                sample_idx = torch.randint(len(data_set), size=(1,)).item()
                img, label = data_set[sample_idx]
                x = torch.reshape(img, (1, 1, 28, 28))
                # Call the module (not .forward) so registered hooks still run.
                pred_label = torch.argmax(model(x)).item()
                if label == pred_label:
                    nb_valid += 1
    finally:
        model.train(was_training)
    return nb_valid / nb_samples

# Train the model

def train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier, lr=0.01):
    """Train ``classifier`` with plain per-sample SGD on an MSE-vs-one-hot loss.

    Each iteration draws one random training sample (with replacement).
    The network output is compared with a one-hot target using
    ``F.mse_loss``, and ``p -= lr * p.grad`` is applied to every parameter.
    Every ``CHECK_PERIOD`` iterations, including iteration 0, test accuracy
    is estimated with ``compute_accuracy`` on ``CHECK_PERIOD`` random test
    samples and printed.

    Args:
        NB_ITERATION: number of SGD steps.
        CHECK_PERIOD: evaluation interval in iterations, also used as the
            number of test samples per evaluation; must be positive.
        train_set: indexable sequence of ``(img, label)`` pairs. ``img`` must
            hold 28*28 elements and ``label`` must be a valid index into 10
            classes.
        test_set: dataset forwarded to ``compute_accuracy``.
        classifier: module mapping a ``(1, 1, 28, 28)`` input to ``(1, 10)``
            scores.
        lr: SGD learning rate (default 0.01, the previously hard-coded value).

    Returns:
        List of the accuracies measured at each checkpoint, in order.

    Raises:
        ValueError: if ``CHECK_PERIOD`` is not positive.
    """
    if CHECK_PERIOD <= 0:
        raise ValueError(f"CHECK_PERIOD must be positive, got {CHECK_PERIOD}")

    accuracy_history = []
    for it in range(NB_ITERATION):
        sample_idx = random.randint(0, len(train_set) - 1)
        img, label = train_set[sample_idx]
        x = torch.reshape(img, (1, 1, 28, 28))
        # One-hot target over the 10 classes.
        y = torch.zeros(1, 10)
        y[0][label] = 1
        y_h = classifier(x)
        # F.mse_loss takes (input, target): the prediction goes first.
        loss = F.mse_loss(y_h, y)
        loss.backward()
        # Manual SGD step, done under no_grad so the update itself is not tracked.
        with torch.no_grad():
            for p in classifier.parameters():
                if p.grad is None:  # parameter not reached by this loss
                    continue
                p -= lr * p.grad
                p.grad.zero_()

        if it % CHECK_PERIOD == 0:
            accuracy = compute_accuracy(classifier, test_set, CHECK_PERIOD)
            accuracy_history.append(accuracy)
            print(f'it {it}: accuracy = {accuracy:.8f} ')
    return accuracy_history


def create_lenet(nb_iteration=50000, check_period=3000, onnx_path='lenet.onnx'):
    """Download MNIST, train a ``Classifier`` on it and export it to ONNX.

    Args:
        nb_iteration: number of SGD iterations (default 50000).
        check_period: iterations between accuracy checks (default 3000).
        onnx_path: destination file for the exported ONNX graph
            (default ``'lenet.onnx'``).

    Returns:
        The trained ``Classifier`` instance.
    """
    # Get Dataset
    train_set, test_set = get_mnist_dataset()

    # Create model
    classifier = Classifier()

    # Train model
    print("NB_ITERATIONS = ", nb_iteration)
    print("CHECK_PERIOD  = ", check_period)
    print("\nTraining LeNet...")
    train_model(nb_iteration, check_period, train_set, test_set, classifier)

    # Export as ONNX. The example input is only used to trace the network
    # (batch 1, 1 channel, 28x28). Use zeros rather than the uninitialised
    # memory returned by the legacy ``torch.Tensor(...)`` constructor.
    classifier.eval()
    example_input = torch.zeros(1, 1, 28, 28)
    torch.onnx.export(classifier.network, example_input, onnx_path, verbose=False,
                      input_names=["input"], output_names=["output"])
    return classifier