File size: 740 Bytes
a0f925f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch

from models import CNN
from dataset import DatasetMNIST, download_mnist
from train import get_dataloaders, train_net_manually, train_net_lightning


def main(device):
    """Run one manual training pass of the CNN on MNIST.

    Fetches MNIST via ``download_mnist``, wraps the "train" and "test"
    entries in ``DatasetMNIST``, builds the data loaders, and hands a
    freshly constructed ``CNN`` to ``train_net_manually``.

    Args:
        device: Target passed to ``.to(...)`` on the model and forwarded
            to ``train_net_manually`` (the script entry point uses
            ``"cpu"``).
    """
    raw_mnist = download_mnist("downloads/mnist/")

    # Each split entry is unpacked positionally into the dataset wrapper.
    # NOTE(review): presumably an (images, labels) pair -- confirm in dataset.py.
    train_split = DatasetMNIST(*raw_mnist["train"])
    test_split = DatasetMNIST(*raw_mnist["test"])

    # The third loader is returned by get_dataloaders but is not used here.
    train_loader, validate_loader, _test_loader = get_dataloaders(
        train_split, test_split
    )

    # Training manually
    model = CNN(input_channels=1, num_classes=10)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = torch.nn.CrossEntropyLoss()
    epoch_limit = 1

    train_net_manually(
        model,
        optimizer,
        criterion,
        train_loader,
        validate_loader,
        epoch_limit,
        device,
    )


if __name__ == "__main__":
    # Script entry point: run the training on the CPU.
    main(device="cpu")