lenet_mnist / main.py
rzimmerdev's picture
Fixed devices and demo
a0f925f
raw
history blame
740 Bytes
import torch
from models import CNN
from dataset import DatasetMNIST, download_mnist
from train import get_dataloaders, train_net_manually, train_net_lightning
def main(device):
mnist = download_mnist("downloads/mnist/")
dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])
train_loader, validate_loader, test_loader = get_dataloaders(dataset, test_data)
# Training manually
net = CNN(input_channels=1, num_classes=10).to(device)
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()
max_epochs = 1
train_net_manually(net, optim, loss_fn, train_loader, validate_loader, max_epochs, device)
if __name__ == "__main__":
main("cpu")