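"""Train ResNet-50 with SGD and a One-Cycle learning-rate schedule.

Supports resuming from a checkpoint and plots accuracy/loss curves after
training. The model, data pipeline, and train/test loops live in the
sibling modules imported below.
"""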
import torch
import torch.nn as nn
import torch.optim as optim
from resnet_model import ResNet50
from data_utils import get_train_transform, get_test_transform, get_data_loaders
from train_test import train, test
from utils import save_checkpoint, load_checkpoint, plot_training_curves, plot_misclassified_samples
from torchsummary import summary
from torch.optim.lr_scheduler import OneCycleLR

def main():
    # Initialize model, loss function, and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet50()
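    # DataParallel replicates the model across all visible GPUs; it is a
    # logical pass-through when only one (or no) GPU is present.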
    model = torch.nn.DataParallel(model)
    model = model.to(device)
    summary(model, input_size=(3, 224, 224), device=str(device))  # torchsummary defaults to "cuda"; pass the actual device so CPU runs work
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

    # Load data
    train_transform = get_train_transform()
    test_transform = get_test_transform()
    trainloader, testloader = get_data_loaders(train_transform, test_transform)

    # Load checkpoint if it exists
    checkpoint_path = "checkpoint.pth"
    try:
        # load_checkpoint is assumed to return the last completed epoch
        model, optimizer, start_epoch, _ = load_checkpoint(model, optimizer, checkpoint_path)
    except FileNotFoundError:
        print("No checkpoint found, starting from scratch.")
        start_epoch = 0  # nothing completed yet, so training begins at epoch 1

    # Store results for plotting
    results = []
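    # each entry: (epoch, train@1, train@5, test@1, test@5, train_loss, test_loss)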
    learning_rates = []

    # Set One-Cycle LR scheduler. OneCycleLR is designed to be stepped once
    # per batch; since train() does not receive the scheduler, the cycle is
    # defined over total_steps=num_epochs and stepped once per epoch below.
    num_epochs = 10
    lr_max = 1e-2

    scheduler = OneCycleLR(optimizer, max_lr=lr_max, total_steps=num_epochs)

    # Training loop
    # Run num_epochs epochs, resuming right after the last completed epoch
    for epoch in range(start_epoch + 1, start_epoch + num_epochs + 1):
        train_accuracy1, train_accuracy5, train_loss = train(model, device, trainloader, optimizer, criterion, epoch)
        test_accuracy1, test_accuracy5, test_loss, misclassified_images, misclassified_labels, misclassified_preds = test(model, device, testloader, criterion)
        print(f'Epoch {epoch} | Train Top-1 Acc: {train_accuracy1:.2f} | Test Top-1 Acc: {test_accuracy1:.2f}')

        # Append results for this epoch
        results.append((epoch, train_accuracy1, train_accuracy5, test_accuracy1, test_accuracy5, train_loss, test_loss))
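        # Record the LR used for this epoch before the scheduler advances it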
        learning_rates.append(optimizer.param_groups[0]['lr'])
        
        scheduler.step()
        # Save checkpoint
        save_checkpoint(model, optimizer, epoch, test_loss, checkpoint_path)

    # Extract results for plotting
    epochs = [r[0] for r in results]
    train_acc1 = [r[1] for r in results]
    train_acc5 = [r[2] for r in results]
    test_acc1 = [r[3] for r in results]
    test_acc5 = [r[4] for r in results]
    train_losses = [r[5] for r in results]
    test_losses = [r[6] for r in results]

    # Plot training curves
    plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates)

    # Plot misclassified samples (uncomment and supply the dataset's class names):
    # plot_misclassified_samples(misclassified_images, misclassified_labels,
    #                            misclassified_preds, classes=['class1', 'class2', ...])

if __name__ == '__main__':
    main()