File size: 3,093 Bytes
159b4ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
"""
Train a simple LeNet-style CNN on the MNIST dataset and export it to ONNX.
"""
import random
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
# Download (if needed) and load the MNIST dataset
def get_mnist_dataset():
    """Load the MNIST train and test splits as tensors.

    Both splits live under ``./data`` and are downloaded on first use; every
    image is converted with ``transforms.ToTensor()``.

    Returns:
        A ``(train_set, test_set)`` tuple of ``torchvision.datasets.MNIST``.
    """
    to_tensor = transforms.ToTensor()
    # Build the train split first, then the test split (same order as returned).
    splits = tuple(
        datasets.MNIST(root='./data', train=is_train, transform=to_tensor, download=True)
        for is_train in (True, False)
    )
    return splits
# LeNet-style model definition
class Classifier(torch.nn.Module):
    """LeNet-style convolutional classifier for 1x28x28 inputs with 10 outputs.

    All layers live in ``self.network`` (an ``nn.Sequential``). Its layer order
    determines the state-dict keys (``network.0.weight``, ...), and the ONNX
    export uses ``network`` directly, so keep both stable.
    """

    def __init__(self):
        super().__init__()
        feature_extractor = [
            nn.Conv2d(1, 32, 5),      # 1x28x28  -> 32x24x24
            nn.ReLU(),
            nn.MaxPool2d(2, 2),       # 32x24x24 -> 32x12x12
            nn.Conv2d(32, 32, 5),     # 32x12x12 -> 32x8x8
            nn.ReLU(),
            nn.MaxPool2d(2, 2),       # 32x8x8   -> 32x4x4
        ]
        head = [
            nn.Flatten(),             # 32x4x4 -> 512 features
            nn.Linear(32 * 4 * 4, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 10),       # raw scores, no softmax applied
        ]
        self.network = nn.Sequential(*feature_extractor, *head)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for batch ``x``."""
        scores = self.network(x)
        return scores
# Estimate classification accuracy on randomly drawn samples
def compute_accuracy(model, data_set, nb_samples):
    """Estimate ``model``'s accuracy on ``nb_samples`` randomly drawn samples.

    Samples are drawn uniformly *with replacement* from ``data_set``, so the
    result is a stochastic estimate rather than the exact accuracy over the
    whole set. Forward passes run under ``torch.no_grad()`` so no autograd
    graph is recorded during evaluation.

    Args:
        model: module mapping a ``(1, 1, 28, 28)`` tensor to per-class scores.
        data_set: indexable dataset of ``(image, label)`` pairs; each image
            must be reshapeable to ``(1, 1, 28, 28)``.
        nb_samples: number of samples to draw; must be positive.

    Returns:
        Fraction of drawn samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``nb_samples`` is not positive or ``data_set`` is empty.
    """
    if nb_samples <= 0:
        raise ValueError(f"nb_samples must be positive, got {nb_samples}")
    if len(data_set) == 0:
        raise ValueError("data_set is empty")
    nb_valid = 0
    # Evaluation only: skip building the autograd graph for every sample.
    with torch.no_grad():
        for _ in range(nb_samples):
            sample_idx = torch.randint(len(data_set), size=(1,)).item()
            img, label = data_set[sample_idx]
            x = torch.reshape(img, (1, 1, 28, 28))
            # Call the module (not .forward) so registered hooks still run.
            pred_label = torch.argmax(model(x)).item()
            if label == pred_label:
                nb_valid += 1
    return nb_valid / nb_samples
# Train the model
def train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier, lr=0.01):
    """Train ``classifier`` with plain single-sample SGD on an MSE/one-hot loss.

    Each iteration draws one random training sample and compares the model's
    10 outputs to the one-hot encoding of its label with mean-squared error.
    It then applies ``p -= lr * p.grad`` to every parameter. Every
    ``CHECK_PERIOD`` iterations (including iteration 0) the test accuracy is
    estimated with ``compute_accuracy`` on ``CHECK_PERIOD`` random test
    samples and printed.

    Args:
        NB_ITERATION: number of single-sample SGD steps.
        CHECK_PERIOD: evaluation period in iterations; also the number of
            test samples used for each accuracy estimate.
        train_set: indexable dataset of ``(image, label)`` pairs.
        test_set: dataset passed to ``compute_accuracy``.
        classifier: module mapping ``(1, 1, 28, 28)`` inputs to 10 scores;
            its parameters are updated in place.
        lr: SGD learning rate (default 0.01).

    Returns:
        List of the accuracy estimates, in the order they were computed.
    """
    accuracy_history = []
    for it in range(NB_ITERATION):
        sample_idx = random.randint(0, len(train_set) - 1)
        img, label = train_set[sample_idx]
        x = torch.reshape(img, (1, 1, 28, 28))
        # One-hot target that the 10 raw outputs are regressed onto.
        y = torch.zeros(1, 10)
        y[0, label] = 1.0
        y_h = classifier(x)
        # mse_loss takes (input, target): prediction first, label second.
        loss = F.mse_loss(y_h, y)
        loss.backward()
        # Manual SGD step; parameters must be mutated outside autograd.
        with torch.no_grad():
            for p in classifier.parameters():
                if p.grad is None:  # parameter not reached by this loss
                    continue
                p -= lr * p.grad
                p.grad.zero_()
        if it % CHECK_PERIOD == 0:
            accuracy = compute_accuracy(classifier, test_set, CHECK_PERIOD)
            accuracy_history.append(accuracy)
            print(f'it {it}: accuracy = {accuracy:.8f} ')
    return accuracy_history
def create_lenet(nb_iteration=50000, check_period=3000, onnx_path='lenet.onnx'):
    """Train a LeNet ``Classifier`` on MNIST and export its network to ONNX.

    Args:
        nb_iteration: number of single-sample SGD steps (default 50000).
        check_period: accuracy-evaluation period passed to ``train_model``
            (default 3000).
        onnx_path: destination file of the exported ONNX model
            (default ``'lenet.onnx'``, relative to the working directory).

    Returns:
        The trained ``Classifier``.
    """
    # Get Dataset
    train_set, test_set = get_mnist_dataset()
    # Create model
    classifier = Classifier()
    # Train model
    print("NB_ITERATIONS = ", nb_iteration)
    print("CHECK_PERIOD = ", check_period)
    print("\nTraining LeNet...")
    train_model(nb_iteration, check_period, train_set, test_set, classifier)
    # Export as ONNX. The example input only fixes shape/dtype for tracing.
    # Use zeros instead of torch.Tensor(...), which is uninitialized memory.
    dummy_input = torch.zeros(1, 1, 28, 28)
    torch.onnx.export(
        classifier.network,
        dummy_input,
        onnx_path,
        verbose=False,
        input_names=["input"],
        output_names=["output"],
    )
    return classifier