File size: 5,971 Bytes
c7f0510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import torch.nn as nn
from torchvision import transforms
from utils.data import PokemonDataModule
from utils.train import initialize_model, train_and_evaluate
import torch
import torch.optim as optim
import mlflow
import argparse
import random

# The shape of the images that the models expect
IMG_SHAPE = (224, 224)


def _str_to_bool(value):
    """Convert a command-line token such as "true" or "False" into a bool.

    ``argparse`` with ``type=bool`` simply calls ``bool(text)``, which is True
    for every non-empty string (including "False"). This converter parses the
    text instead so boolean options can actually be switched off.

    Args:
        value: The raw option value (a string from the command line, or an
            already-converted bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean
            spelling; argparse reports this as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, matching what ``bool("")`` used to yield.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def parser_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with the attributes ``data_dir``, ``indices_file``,
        ``epochs``, ``lr``, ``train_batch_size``, ``test_batch_size``,
        ``model``, ``feature_extract``, ``use_pretrained`` and
        ``experiment_id``.
    """
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./pokemonclassification/PokemonData",
        help="Path to the data directory",
    )
    parser.add_argument(
        "--indices_file",
        type=str,
        default="indices_60_32.pkl",
        help="Path to the indices file",
    )
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batch_size", type=int, default=128, help="train Batch size"
    )
    parser.add_argument(
        "--test_batch_size", type=int, default=512, help="test Batch size"
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"],
        default="resnet",
        help="Model to be used",
    )
    # Boolean flags use an explicit converter: ``type=bool`` would turn the
    # string "False" into True.
    parser.add_argument(
        "--feature_extract",
        type=_str_to_bool,
        default=True,
        help="whether to freeze the backbone or not",
    )
    parser.add_argument(
        "--use_pretrained",
        type=_str_to_bool,
        default=True,
        help="whether to use pretrained model or not",
    )
    parser.add_argument(
        "--experiment_id",
        type=int,
        default=0,
        help="Experiment ID to log the results",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the data pipeline and
    # the model, then train/evaluate inside an MLflow run and save the weights.
    args = parser_args()

    pokemon_dataset = PokemonDataModule(args.data_dir)
    NUM_CLASSES = len(pokemon_dataset.class_names)

    print(f"Number of classes: {NUM_CLASSES}")

    # The precomputed statistics can only be used with the matching indices
    # file ('indices_60_32.pkl'); for any other file prepare_data is asked to
    # return them (get_stats=True).
    if "indices_60_32.pkl" in args.indices_file:
        channel_means = torch.tensor([0.6062, 0.5889, 0.5550])
        # NOTE(review): these values were originally named "vars" but are fed
        # to Normalize as `std` -- confirm they are standard deviations.
        channel_stds = torch.tensor([0.3284, 0.3115, 0.3266])
        stats = {"mean": channel_means, "std": channel_stds}
        # Return value intentionally ignored: statistics are hard-coded above.
        pokemon_dataset.prepare_data(indices_file=args.indices_file, get_stats=False)
    else:
        stats = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=True
        )

    print(f"Train dataset size: {len(pokemon_dataset.train_dataset)}")
    print(f"Test dataset size: {len(pokemon_dataset.test_dataset)}")

    # Transformations of data for testing (no augmentation)
    test_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.ToTensor(),  # Convert PIL images to tensors
            transforms.Normalize(**stats),  # Normalize images using mean and std
        ]
    )

    # Data augmentations for training
    train_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(IMG_SHAPE, padding=4),
            transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(**stats),
        ]
    )

    # Get dataloaders
    trainloader, testloader = pokemon_dataset.get_dataloaders(
        train_transform=train_transform,
        test_transform=test_transform,
        train_batch_size=args.train_batch_size,
        test_batch_size=args.test_batch_size,
    )

    # Plot examples from both loaders; the stats are passed along with them.
    pokemon_dataset.plot_examples(testloader, stats=stats)
    pokemon_dataset.plot_examples(trainloader, stats=stats)

    # Build the requested architecture (e.g. a ResNet to fine-tune)
    model = initialize_model(
        args.model,
        NUM_CLASSES,
        feature_extract=args.feature_extract,
        use_pretrained=args.use_pretrained,
    )

    # Print the model we just instantiated
    print(model)

    # Criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Run name encodes the configuration plus a random suffix to tell apart
    # repeated runs of the same configuration.
    run_name = (
        f"{args.model}_{'finetuning' if not args.feature_extract else 'feature_extracting'}"
        f"_{'pretrained' if args.use_pretrained else 'not_pretrained'}"
        f"_{args.indices_file}_{random.randint(0, 1000)}"
    )
    weights_path = f"pokemon_{args.model}.pth"

    # MLflow documents experiment_id as a string, so the integer CLI value is
    # converted. The context manager ends the run on exit, so no explicit
    # mlflow.end_run() call is needed.
    with mlflow.start_run(
        experiment_id=str(args.experiment_id),
        run_name=run_name,
    ):
        mlflow.log_param("epochs", args.epochs)
        mlflow.log_param("lr", args.lr)
        mlflow.log_param("train_batch_size", args.train_batch_size)
        mlflow.log_param("test_batch_size", args.test_batch_size)
        mlflow.log_param("model", args.model)
        mlflow.log_param("feature_extract", args.feature_extract)
        mlflow.log_param("use_pretrained", args.use_pretrained)

        # Train and evaluate
        history = train_and_evaluate(
            model=model,
            trainloader=trainloader,
            testloader=testloader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epochs=args.epochs,
            use_mlflow=True,
        )
        # Save the model weights and attach them to the run as an artifact
        torch.save(model.state_dict(), weights_path)
        mlflow.log_artifact(weights_path)