# mnist / train_conv.py
# gaviego's picture
# fix for train
# 854c6e9
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from models import NetConv
import torch.nn.functional as F
# --- Hyper-parameters and run configuration ---------------------------------
batch_size = 64            # samples per optimisation step
test_batch_size = 1000     # samples per batch in test_loader
epochs = 20                # full passes over the training set
lr = 0.01                  # SGD learning rate
momentum = 0.5             # SGD momentum
seed = 1                   # value passed to torch.manual_seed
log_interval = 10          # print the loss every `log_interval` batches

# Seed PyTorch's RNG before anything random (weight init, shuffling) happens.
torch.manual_seed(seed)

# Select the GPU when one is visible; the extra DataLoader options are only
# supplied in that case.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
if _use_cuda:
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    kwargs = {}

# --- MNIST data -------------------------------------------------------------
# Both splits are read from ./data/ and converted with ToTensor();
# download=True is passed only for the training split.
train_dataset = datasets.MNIST(
    root="./data/",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root="./data/",
    train=False,
    transform=transforms.ToTensor(),
)

# Training batches are reshuffled on each pass; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    **kwargs,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    **kwargs,
)

# --- Model and optimizer ----------------------------------------------------
model = NetConv().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# Training loop: `epochs` passes over train_loader, one SGD step per batch.
for epoch in range(1, epochs + 1):
    # Put the module in training mode at the start of every epoch.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Move the batch to the same device as the model.
        data, target = data.to(device), target.to(device)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        output = model(data)
        # F.nll_loss expects log-probabilities as input.
        # NOTE(review): assumes NetConv.forward ends in log_softmax — confirm
        # against models.py.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # Periodic progress line: samples seen so far (approximated as
        # batch index * current batch size), percent of batches done, loss.
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

# NOTE(review): the original indentation of this statement was lost; saving
# once after the final epoch is assumed — confirm it was not meant to run
# per epoch.
# This pickles the whole module object, so loading the file later requires
# the NetConv class to be importable; saving model.state_dict() is the more
# portable alternative, but would change the file format consumers load.
torch.save(model, 'mnist_conv.pth')