|
""" |
|
This file trains a simple LeNet-style network on the MNIST dataset and exports it to ONNX.
|
""" |
|
|
|
import random |
|
|
|
import torch |
|
from torchvision import datasets, transforms |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
def get_mnist_dataset(root='./data', download=True):
    """Load the MNIST train and test splits.

    Args:
        root: Directory where torchvision stores (or looks for) the MNIST
            files. Defaults to ``'./data'``, the previously hard-coded path.
        download: Passed through to ``datasets.MNIST``; when True the files
            are fetched into ``root`` if missing. Defaults to True, as before.

    Returns:
        A ``(train_set, test_set)`` tuple of ``datasets.MNIST`` objects whose
        images are converted with ``transforms.ToTensor()``.
    """
    transform = transforms.ToTensor()
    train_set = datasets.MNIST(root=root, train=True, transform=transform, download=download)
    test_set = datasets.MNIST(root=root, train=False, transform=transform, download=download)
    return train_set, test_set
|
|
|
|
|
|
|
class Classifier(torch.nn.Module):
    """LeNet-style convolutional classifier that outputs 10 raw class scores.

    All layers live in a single ``nn.Sequential`` stored as ``self.network``.
    For a ``(N, 1, 28, 28)`` input the spatial size evolves as
    conv5 -> 24, pool -> 12, conv5 -> 8, pool -> 4, which is why the first
    fully connected layer takes ``32 * 4 * 4`` features.
    """

    def __init__(self):
        super().__init__()
        # Two conv/ReLU/max-pool stages.
        conv_stages = [
            nn.Conv2d(1, 32, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 32, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
        ]
        # Flatten followed by a three-layer MLP head ending in 10 logits.
        mlp_head = [
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 100), nn.ReLU(),
            nn.Linear(100, 100), nn.ReLU(),
            nn.Linear(100, 10),
        ]
        self.network = nn.Sequential(*conv_stages, *mlp_head)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the class scores."""
        logits = self.network(x)
        return logits
|
|
|
|
|
|
|
def compute_accuracy(model, data_set, nb_samples):
    """Estimate classification accuracy on randomly drawn samples.

    ``nb_samples`` indices are drawn uniformly *with replacement* from
    ``data_set`` using torch's RNG, so the result is a Monte-Carlo estimate
    rather than the exact accuracy over the whole set. Each image is reshaped
    to ``(1, 1, 28, 28)``. The predicted label is the argmax of the model
    output.

    The model is evaluated in eval mode under ``torch.no_grad()`` so no
    autograd graph is built. Its previous train/eval mode is restored before
    returning.

    Args:
        model: ``nn.Module`` mapping a ``(1, 1, 28, 28)`` tensor to class scores.
        data_set: Non-empty indexable sequence of ``(image, label)`` pairs.
        nb_samples: Number of samples to draw; must be positive.

    Returns:
        Fraction of sampled items whose predicted label equals their label.

    Raises:
        ValueError: If ``nb_samples`` is not positive or ``data_set`` is empty.
    """
    if nb_samples <= 0:
        raise ValueError(f"nb_samples must be positive, got {nb_samples}")
    if len(data_set) == 0:
        raise ValueError("data_set must not be empty")

    was_training = model.training
    model.eval()
    nb_valid = 0
    try:
        with torch.no_grad():
            for _ in range(nb_samples):
                sample_idx = torch.randint(len(data_set), size=(1,)).item()
                img, label = data_set[sample_idx]
                x = torch.reshape(img, (1, 1, 28, 28))
                # Call the module (not .forward) so registered hooks run.
                pred_label = torch.argmax(model(x)).item()
                if int(label) == pred_label:
                    nb_valid += 1
    finally:
        model.train(was_training)

    return nb_valid / nb_samples
|
|
|
|
|
|
|
def train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier, lr=0.01):
    """Train ``classifier`` with single-sample SGD on an MSE-to-one-hot loss.

    Each iteration does the following:
      - draws one random example from ``train_set`` (with replacement);
      - reshapes it to ``(1, 1, 28, 28)``;
      - takes one plain SGD step (no momentum, no weight decay) on the mean
        squared error between the model output and the one-hot label.

    Every ``CHECK_PERIOD`` iterations, starting at iteration 0, accuracy is
    estimated with ``compute_accuracy`` on ``CHECK_PERIOD`` random test
    samples and then printed.

    Args:
        NB_ITERATION: Number of training steps.
        CHECK_PERIOD: Evaluation period in iterations. It is also the number
            of test samples drawn per evaluation. Must be positive.
        train_set: Indexable sequence of ``(image, label)`` pairs.
        test_set: Indexable sequence of ``(image, label)`` pairs for evaluation.
        classifier: ``nn.Module`` producing 10 class scores per image.
        lr: SGD learning rate. Defaults to 0.01, the previously hard-coded value.

    Returns:
        List of accuracy estimates in the order they were computed.

    Raises:
        ValueError: If ``CHECK_PERIOD`` is not positive.
    """
    if CHECK_PERIOD <= 0:
        raise ValueError(f"CHECK_PERIOD must be positive, got {CHECK_PERIOD}")

    # Plain SGD: p <- p - lr * grad, the same update as a manual parameter loop.
    optimizer = torch.optim.SGD(classifier.parameters(), lr=lr)
    accuracy_history = []
    classifier.train()

    for it in range(NB_ITERATION):
        # Use torch's RNG, like compute_accuracy, so that
        # torch.manual_seed alone makes a run reproducible.
        sample_idx = torch.randint(len(train_set), size=(1,)).item()
        img, label = train_set[sample_idx]
        x = torch.reshape(img, (1, 1, 28, 28))

        y = torch.zeros(1, 10)
        y[0, label] = 1.0

        optimizer.zero_grad()
        y_h = classifier(x)
        loss = F.mse_loss(y_h, y)  # (input, target) order
        loss.backward()
        optimizer.step()

        if it % CHECK_PERIOD == 0:
            accuracy = compute_accuracy(classifier, test_set, CHECK_PERIOD)
            accuracy_history.append(accuracy)
            print(f'it {it}: accuracy = {accuracy:.8f} ')

    return accuracy_history
|
|
|
|
|
def create_lenet(nb_iteration=50000, check_period=3000, onnx_path='lenet.onnx'):
    """Train a LeNet-style ``Classifier`` on MNIST and export it to ONNX.

    Loads MNIST via ``get_mnist_dataset`` and trains a fresh ``Classifier``
    with ``train_model``. It then exports ``classifier.network`` to
    ``onnx_path``, with one input named ``"input"`` traced at shape
    ``(1, 1, 28, 28)`` and one output named ``"output"``.

    Args:
        nb_iteration: Number of single-sample training steps (default 50000).
        check_period: Iterations between accuracy checks (default 3000).
        onnx_path: Destination file for the exported model (default 'lenet.onnx').

    Returns:
        The trained ``Classifier``.
    """
    train_set, test_set = get_mnist_dataset()
    classifier = Classifier()

    print("NB_ITERATIONS = ", nb_iteration)
    print("CHECK_PERIOD = ", check_period)
    print("\nTraining LeNet...")
    train_model(nb_iteration, check_period, train_set, test_set, classifier)

    # Export in inference mode. Use a deterministic dummy input, because
    # torch.Tensor(*shape) returns uninitialised memory that may hold NaN/Inf.
    classifier.eval()
    dummy_input = torch.zeros(1, 1, 28, 28)
    torch.onnx.export(classifier.network, dummy_input, onnx_path, verbose=False,
                      input_names=["input"], output_names=["output"])
    return classifier