# mnist_interface/model.py — MNIST CNN model definition and training script.
# (Web-page scrape header removed; last update commit e4b9ec4.)
import os
import time
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
def get_mean_std(loader):
    '''
    Calculates the per-channel mean and standard deviation of the images
    yielded by a loader, without materializing the whole dataset.

    Uses the identity Var[X] = E[X^2] - E[X]^2 on per-batch means.
    NOTE(review): averaging per-batch means is exact only when all batches
    have the same size; the last (smaller) batch is weighted slightly high.

    Args:
        loader: iterable yielding (images, labels) pairs, where images is a
            (batch, channels, height, width) float tensor.
    Returns:
        mean (torch.Tensor): per-channel mean, shape (channels,)
        std (torch.Tensor): per-channel standard deviation, shape (channels,)
    '''
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for data, _ in loader:
        # mean over batch, height, width -> one value per channel
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1
    mean = channels_sum / num_batches
    # BUG FIX: original computed (E[X^2] / (N - mean^2)) ** 0.5 due to a
    # misplaced parenthesis; the correct identity is sqrt(E[X^2] - E[X]^2).
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
    return mean, std
class Net(nn.Module):
    '''
    CNN MNIST classifier: four conv layers followed by four fully-connected
    layers, emitting per-class log-probabilities (LogSoftmax) so it pairs
    with nn.NLLLoss. Expects (batch, 1, 28, 28) input.
    '''
    def __init__(self, mean=None, std=None):
        '''
        Args:
            mean, std: optional per-channel normalization statistics
                (sequence or tensor). When given, they are stored as buffers
                so .to(device) / state_dict handle them. When omitted,
                forward() falls back to the module-level `mean`/`std`
                globals, preserving the original behavior of Net().
        '''
        super(Net, self).__init__()
        if mean is not None:
            self.register_buffer(
                "norm_mean",
                torch.as_tensor(mean, dtype=torch.float32).reshape(1, -1, 1, 1))
            self.register_buffer(
                "norm_std",
                torch.as_tensor(std, dtype=torch.float32).reshape(1, -1, 1, 1))
        else:
            self.register_buffer("norm_mean", None)
            self.register_buffer("norm_std", None)
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5),     # 28 -> 24
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, bias=False),  # 24 -> 20
            nn.BatchNorm2d(32),                  # bias=False: BN has its own shift
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),                # 20 -> 10
            nn.Dropout2d(0.25),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),    # 10 -> 8
            nn.ReLU(),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, bias=False),  # 8 -> 6
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),                # 6 -> 3
            nn.Dropout2d(0.25),
            nn.Flatten(),                        # 64 * 3 * 3 = 576
        )
        self.layer5 = nn.Sequential(
            nn.Linear(576, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        self.layer6 = nn.Sequential(
            nn.Linear(256, 128, bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        self.layer7 = nn.Sequential(
            nn.Linear(128, 84, bias=False),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Dropout(0.25),
        )
        self.layer8 = nn.Sequential(
            nn.Linear(84, 10),
            nn.LogSoftmax(dim=1),                # log-probabilities for NLLLoss
        )
    def forward(self, x):
        '''Normalize the input, then run it through layers 1-8.
        Returns a (batch, 10) tensor of log-probabilities.'''
        if self.norm_mean is not None:
            # Equivalent to transforms.Normalize(mean, std) but avoids
            # constructing a transform object on every forward pass.
            x = (x - self.norm_mean) / self.norm_std
        else:
            # Backward-compatible fallback: module-level `mean`/`std`
            # computed from the training loader at import time.
            x = transforms.Normalize(mean, std)(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)
        return x
# Downloads and loads the MNIST train set. RandomAffine provides light
# train-time augmentation; placed after ToTensor it operates on tensors.
transform = transforms.Compose([transforms.ToTensor(), transforms.RandomAffine(degrees=10, translate=(0.1,0.1))])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, pin_memory=True)
# Downloads and loads the MNIST test set (no augmentation).
val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
# BUG FIX: the validation loader previously wrapped train_data, so all
# "validation" metrics were actually computed on the training set.
val_loader = DataLoader(dataset=val_data, batch_size=64, shuffle=False, pin_memory=True)
# Uses GPU if available.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
# Per-channel mean/std of the training set, used by Net.forward for input
# normalization. NOTE(review): computed on the *augmented* loader, so the
# statistics include the random-affine jitter — confirm this is intended.
mean, std = get_mean_std(train_loader)
def run_model():
    '''
    Trains the CNN on MNIST for 30 epochs, printing train/validation loss
    and accuracy each epoch, checkpointing the model after every epoch, and
    reducing the learning rate when validation loss plateaus.

    Relies on the module-level `train_loader`, `val_loader`, `train_data`,
    `val_data` and `device`.
    '''
    # defines parameters
    model = Net().to(device=device)
    # NOTE(review): lr=0.1 is unusually high for Adam (typical ~1e-3); kept
    # as-is to preserve behavior — the plateau scheduler will shrink it.
    optimizer = optim.Adam(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=2)
    criterion = nn.NLLLoss()  # model outputs log-probabilities (LogSoftmax)
    # iterates through epochs
    for epoch in range(30):
        print(f"\nEpoch {epoch+1}/{30}.")
        # ---- train loop ----
        model.train()
        total_train_loss = 0
        total_correct = 0
        for i, (images, labels) in enumerate(tqdm(train_loader)):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_train_loss += loss.item()
            loss.backward()
            optimizer.step()
            # Vectorized accuracy count (replaces the original per-sample
            # Python loop): argmax of the logits equals argmax of softmax,
            # so the softmax pass is unnecessary.
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        train_loss = total_train_loss/(i+1)
        train_accuracy = total_correct/len(train_data)
        print(f"Train set:- Loss: {train_loss}, Accuracy: {train_accuracy}.")
        # saves model state after every epoch
        if not os.path.exists("./saved_models"):
            os.mkdir("./saved_models")
        torch.save(model.state_dict(), f"./saved_models/mnist-cnn-{time.time()}.pt")
        # ---- val loop ----
        model.eval()
        total_val_loss = 0
        total_correct = 0
        with torch.no_grad():
            for i, (images, labels) in enumerate(tqdm(val_loader)):
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                total_val_loss += loss.item()
                total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        val_loss = total_val_loss/(i+1)
        val_accuracy = total_correct/len(val_data)
        print(f"Val set:- Loss: {val_loss}, Accuracy: {val_accuracy}.")
        # adjusts lr when validation loss stops improving
        scheduler.step(val_loss)