# DigitClassifier / train_model.py — trains the final MNIST digit classifier.
import torch
import torchvision
import multiprocessing, prettytable
import torchvision.transforms as transforms
from neural_network import MNISTNetwork
# ---- training hyperparameters ----
BATCH_SIZE = 64        # samples per mini-batch
NUM_WORKERS = 2        # DataLoader worker processes
EPOCH = 15             # passes over the training set
LEARNING_RATE = 0.01   # SGD step size
MOMENTUM = 0.5         # SGD momentum term
LOSS = torch.nn.CrossEntropyLoss()  # multi-class loss; expects raw logits
## Step 1: define our transforms
# Convert PIL images to [0, 1] float tensors, then shift to roughly [-1, 1]
# via (x - 0.5) / 0.5. MNIST is single-channel, so one mean/std value each.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Original wrote (0.5), which is just the float 0.5, not a tuple;
        # (0.5,) makes the one-channel intent explicit.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
## Step 2: get our datasets
# download=True is a no-op when ./data already holds MNIST, but keeps the
# script from crashing on a fresh checkout (original used download=False).
full_ds = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(full_ds))  # use 80% of the data for training
val_size = len(full_ds) - train_size  # remaining 20% for validation
# NOTE(review): the split is unseeded, so train/val membership differs per
# run; pass generator=torch.Generator().manual_seed(...) for reproducibility.
train_ds, valid_ds = torch.utils.data.random_split(full_ds, [train_size, val_size])
test_ds = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
## Step 3: create our dataloaders
# Training batches are reshuffled every epoch; validation/test order is fixed.
_loader = torch.utils.data.DataLoader
train_dl = _loader(train_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
valid_dl = _loader(valid_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
test_dl = _loader(test_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
## Step 4: model, optimizer, and the per-epoch summary table
model = MNISTNetwork()  # project-local network (see neural_network.py)
# NOTE(review): "criteron" is a typo for "criterion"; the name is kept
# because the training loop below references it.
criteron = LOSS
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
)
# One row per epoch: (epoch index, mean training loss, validation accuracy %).
table = prettytable.PrettyTable()
table.field_names = ['Epoch', 'Training Loss', 'Validation Accuracy']
if __name__ == "__main__":
multiprocessing.freeze_support()
# begin training process
for e in range(EPOCH):
model.train()
running_loss = 0.0
for inputs, labels in train_dl:
optimizer.zero_grad()
outputs = model(inputs)
loss = criteron(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
train_loss = round(running_loss/len(train_dl), 4)
# evaluate on the test set
model.eval()
with torch.no_grad():
total, correct = 0, 0
for inputs, labels in valid_dl:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
val_acc = round((correct/total)*100, 3)
table.add_row([e, train_loss, val_acc])
print(f'Training Loss: {train_loss}, Validation Accuracy: {val_acc}')
print(table)
# evaluate on test set
model.eval()
with torch.no_grad():
total, correct = 0, 0
for inputs, labels in test_dl:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
test_acc = round((correct/total)*100, 3)
print(f'Test Accuracy: {test_acc}')
torch.save(model.state_dict(), 'MNISTModel.pth')