import multiprocessing

import prettytable
import torch
import torchvision
import torchvision.transforms as transforms

from neural_network import MNISTNetwork

# hyperparameters
BATCH_SIZE = 64
NUM_WORKERS = 2
EPOCH = 15
LEARNING_RATE = 0.01
MOMENTUM = 0.5
LOSS = torch.nn.CrossEntropyLoss()
## Step 1: define our transforms
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # single-channel mean/std for MNIST
    ]
)
## Step 2: get our datasets
# download=True fetches MNIST on the first run and is a no-op once the files already exist in ./data
full_ds = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(full_ds))  # use 80% of the data for training
val_size = len(full_ds) - train_size  # use the remaining 20% for validation
train_ds, valid_ds = torch.utils.data.random_split(full_ds, [train_size, val_size])
test_ds = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
## Step 3: create our dataloaders
train_dl = torch.utils.data.DataLoader(train_ds, num_workers=NUM_WORKERS, shuffle=True, batch_size=BATCH_SIZE)
valid_dl = torch.utils.data.DataLoader(valid_ds, num_workers=NUM_WORKERS, shuffle=False, batch_size=BATCH_SIZE)
test_dl = torch.utils.data.DataLoader(test_ds, num_workers=NUM_WORKERS, shuffle=False, batch_size=BATCH_SIZE)

## Step 4: define our model and optimizer
model = MNISTNetwork()
criterion = LOSS  # our loss function (cross-entropy)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

## define our results table
table = prettytable.PrettyTable()
table.field_names = ['Epoch', 'Training Loss', 'Validation Accuracy']
if __name__ == "__main__":
    multiprocessing.freeze_support()  # required when the script is frozen into a Windows executable

    # begin training process
    for e in range(EPOCH):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_dl:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = round(running_loss / len(train_dl), 4)

        # evaluate on the validation set
        model.eval()
        with torch.no_grad():
            total, correct = 0, 0
            for inputs, labels in valid_dl:
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_acc = round((correct / total) * 100, 3)

        table.add_row([e, train_loss, val_acc])
        print(f'Training Loss: {train_loss}, Validation Accuracy: {val_acc}')

    print(table)
    # evaluate on the test set
    model.eval()
    with torch.no_grad():
        total, correct = 0, 0
        for inputs, labels in test_dl:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    test_acc = round((correct / total) * 100, 3)
    print(f'Test Accuracy: {test_acc}')

    # save the trained weights
    torch.save(model.state_dict(), 'MNISTModel.pth')
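
The model class is imported from neural_network.py, which is not shown in this file. For reference, the following is a minimal sketch of what that module might contain; only the class name MNISTNetwork comes from the script above, and the layer sizes and architecture here are assumptions, not the actual definition used by this Space.

# neural_network.py -- hypothetical sketch; the real MNISTNetwork may differ
import torch
import torch.nn as nn


class MNISTNetwork(nn.Module):
    """A minimal fully connected classifier for 28x28 MNIST digits (assumed architecture)."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()         # (N, 1, 28, 28) -> (N, 784)
        self.fc1 = nn.Linear(28 * 28, 128)  # hidden layer width is an assumption
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)       # 10 digit classes

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        return self.fc2(x)                  # raw logits; CrossEntropyLoss applies softmax internally

Whatever architecture neural_network.py actually defines, the weights saved above can later be restored for inference with model.load_state_dict(torch.load('MNISTModel.pth')) followed by model.eval().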