"""Train ``MNISTNetwork`` on MNIST with SGD and report accuracy.

The MNIST training set is split randomly into training and validation
subsets (``TRAIN_FRACTION`` for training). The model is trained for ``EPOCH``
epochs, and a per-epoch summary table is printed. Accuracy on the MNIST test
set is then reported and the final weights are saved to ``MODEL_PATH``.
"""

import multiprocessing

import prettytable
import torch
import torchvision
import torchvision.transforms as transforms

from neural_network import MNISTNetwork

# hyperparameters
BATCH_SIZE = 64
NUM_WORKERS = 2
EPOCH = 15
LEARNING_RATE = 0.01
MOMENTUM = 0.5
LOSS = torch.nn.CrossEntropyLoss()

# data / output configuration
DATA_ROOT = './data'
MODEL_PATH = 'MNISTModel.pth'
TRAIN_FRACTION = 0.8  # share of the MNIST training set used for training; the rest is validation
SPLIT_SEED = None     # set to an int for a reproducible train/validation split

## Step 1: define our transforms
# Normalize expects per-channel sequences; MNIST has a single channel, hence 1-tuples.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


def build_dataloaders():
    """Build the train, validation and test DataLoaders for MNIST.

    The datasets are read from ``DATA_ROOT`` with ``download=False``, so MNIST
    must already be present there.

    Returns:
        A ``(train_dl, valid_dl, test_dl)`` tuple. Only the training loader
        shuffles.
    """
    full_ds = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=False, transform=transform)
    train_size = int(TRAIN_FRACTION * len(full_ds))
    val_size = len(full_ds) - train_size
    # Without a seed, fall back to torch's global RNG (the original behaviour).
    generator = torch.default_generator if SPLIT_SEED is None else torch.Generator().manual_seed(SPLIT_SEED)
    train_ds, valid_ds = torch.utils.data.random_split(full_ds, [train_size, val_size], generator=generator)
    test_ds = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, download=False, transform=transform)

    def make_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset, num_workers=NUM_WORKERS, shuffle=shuffle, batch_size=BATCH_SIZE
        )

    return make_loader(train_ds, True), make_loader(valid_ds, False), make_loader(test_ds, False)


def train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass of *model* over *loader*.

    Returns:
        The mean of the per-batch losses, or NaN if *loader* yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches if num_batches else float('nan')


@torch.no_grad()
def evaluate(model, loader):
    """Return the classification accuracy of *model* over *loader*, in percent.

    The predicted class is the argmax over dimension 1 of the model output.
    Returns NaN if *loader* yields no samples.
    """
    model.eval()
    total, correct = 0, 0
    for inputs, labels in loader:
        predicted = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return (correct / total) * 100 if total else float('nan')


def main():
    """Train, validate each epoch, test, and save the model weights."""
    # Everything with I/O or heavy setup lives here, not at module level, so
    # that spawned DataLoader workers re-importing this module don't redo it.
    train_dl, valid_dl, test_dl = build_dataloaders()

    model = MNISTNetwork()
    criterion = LOSS
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

    table = prettytable.PrettyTable()
    table.field_names = ['Epoch', 'Training Loss', 'Validation Accuracy']

    for epoch in range(1, EPOCH + 1):
        train_loss = round(train_one_epoch(model, train_dl, criterion, optimizer), 4)
        # per-epoch evaluation uses the held-out validation split, not the test set
        val_acc = round(evaluate(model, valid_dl), 3)
        table.add_row([epoch, train_loss, val_acc])
        print(f'Epoch {epoch}/{EPOCH} - Training Loss: {train_loss}, Validation Accuracy: {val_acc}')

    print(table)

    # final, one-off evaluation on the untouched test set
    test_acc = round(evaluate(model, test_dl), 3)
    print(f'Test Accuracy: {test_acc}')

    torch.save(model.state_dict(), MODEL_PATH)


if __name__ == "__main__":
    multiprocessing.freeze_support()
    main()