# mnist / train_conv.py
# Last commit: "fix for train" (854c6e9), by gaviego
# File size: 1.65 kB
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from models import NetConv
import torch.nn.functional as F
# Training settings
batch_size = 64
test_batch_size = 1000
epochs = 20
lr = 0.01
momentum = 0.5
seed = 1
log_interval = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A worker process and pinned host memory are only requested when CUDA is
# available; on CPU the DataLoader defaults are used.
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}


def train_one_epoch(model, loader, optimizer, epoch, device=device):
    """Run one optimisation pass over ``loader`` and log progress.

    Args:
        model: Module to train. Its output is fed to ``F.nll_loss``.
            NOTE(review): nll_loss expects log-probabilities, so NetConv
            presumably ends in log_softmax -- confirm in models.py.
        loader: DataLoader yielding ``(data, target)`` batches; must expose
            ``.dataset`` (used for the progress counter).
        optimizer: Optimizer bound to ``model``'s parameters.
        epoch: Epoch number, used only in the log line.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The loss of the last batch as a float, or ``None`` if ``loader``
        yielded no batches.
    """
    model.train()
    loss = None
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(loader.dataset),
                100. * batch_idx / len(loader), loss.item()))
    # .item() is deferred to here so non-logging batches do not force a
    # host/device synchronisation.
    return None if loss is None else loss.item()


def evaluate(model, loader, device=device):
    """Compute the mean NLL loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        ``(mean_loss, accuracy)`` where accuracy is a fraction in [0, 1].
        Both are 0.0 when the loader's dataset is empty.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so the final average is per-sample
            # even when the last batch is smaller.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            correct += (output.argmax(dim=1) == target).sum().item()
    count = len(loader.dataset)
    if count == 0:
        return 0.0, 0.0
    return total_loss / count, correct / count


def main():
    """Train NetConv on MNIST for ``epochs`` epochs and save it to disk."""
    torch.manual_seed(seed)

    # MNIST Dataset
    train_dataset = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)

    model = NetConv().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # Training loop
    for epoch in range(1, epochs + 1):
        train_one_epoch(model, train_loader, optimizer, epoch)

    # NOTE(review): this pickles the whole module object, so loading the file
    # requires the NetConv class to be importable. Kept as-is because the
    # consumer of 'mnist_conv.pth' is not visible from this file; saving
    # model.state_dict() is the more portable option if the loader can change.
    # Saved before evaluation so the checkpoint is written even if the
    # evaluation pass fails.
    torch.save(model, 'mnist_conv.pth')

    test_loss, accuracy = evaluate(model, test_loader)
    print('Test set: Average loss: {:.4f}, Accuracy: {:.2f}%'.format(
        test_loss, 100. * accuracy))


# The guard is required for multi-process data loading on platforms that
# start workers with "spawn": each worker re-imports this module and must not
# re-run the training script.
if __name__ == '__main__':
    main()