import argparse
import random

import mlflow
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

from utils.data import PokemonDataModule
from utils.train import initialize_model, train_and_evaluate

# Input image size the model expects (the torchvision backbones below --
# resnet, alexnet, vgg, squeezenet, densenet -- all default to 224x224)
IMG_SHAPE = (224, 224)


def parser_args():
    """Parse the command-line arguments for the training script."""
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./pokemonclassification/PokemonData",
        help="Path to the data directory",
    )
    parser.add_argument(
        "--indices_file",
        type=str,
        default="indices_60_32.pkl",
        help="Path to the indices file",
    )
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batch_size", type=int, default=128, help="Train batch size"
    )
    parser.add_argument(
        "--test_batch_size", type=int, default=512, help="Test batch size"
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"],
        default="resnet",
        help="Model to be used",
    )
    # Note: type=bool is a classic argparse pitfall (any non-empty string,
    # including "False", parses as True), so expose explicit on/off flags
    # instead (BooleanOptionalAction requires Python 3.9+)
    parser.add_argument(
        "--feature_extract",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to freeze the backbone (--no-feature_extract to fine-tune it)",
    )
    parser.add_argument(
        "--use_pretrained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to start from pretrained weights",
    )
    parser.add_argument(
        "--experiment_id",
        type=int,
        default=0,
        help="Experiment ID to log the results",
    )
    return parser.parse_args()
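

# Example invocations (assuming this file is saved as train.py; the script
# name is not given in the snippet, so adjust as needed):
#   python train.py --model resnet --epochs 20 --lr 0.001
#   python train.py --model densenet --no-feature_extract --train_batch_size 64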

if __name__ == "__main__":
    args = parser_args()

    pokemon_dataset = PokemonDataModule(args.data_dir)
    # Get the class names to size the classification head
    NUM_CLASSES = len(pokemon_dataset.class_names)
    print(f"Number of classes: {NUM_CLASSES}")

    # The precomputed channel means/stds are only valid for the default
    # split file ('indices_60_32.pkl'); for any other indices file they
    # must be recomputed from the training split
    if "indices_60_32.pkl" in args.indices_file:
        channel_means = torch.tensor([0.6062, 0.5889, 0.5550])
        channel_stds = torch.tensor([0.3284, 0.3115, 0.3266])
        stats = {"mean": channel_means, "std": channel_stds}
        _ = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=False
        )
    else:
        stats = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=True
        )
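    # For reference, the precomputed stats can be reproduced by hand with
    # something like the sketch below (assumes pokemon_dataset.train_dataset
    # yields raw, un-normalized PIL images -- an assumption about the
    # PokemonDataModule internals):
    #   to_tensor = transforms.Compose([transforms.Resize(IMG_SHAPE), transforms.ToTensor()])
    #   imgs = torch.stack([to_tensor(img) for img, _ in pokemon_dataset.train_dataset])
    #   mean, std = imgs.mean(dim=(0, 2, 3)), imgs.std(dim=(0, 2, 3))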
print(f"Train dataset size: {len(pokemon_dataset.train_dataset)}")
print(f"Test dataset size: {len(pokemon_dataset.test_dataset)}")

    # Deterministic transforms for evaluation: resize, convert to tensor,
    # and normalize with the training-set statistics
    test_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.ToTensor(),  # Convert PIL images to tensors
            transforms.Normalize(**stats),  # Normalize images using mean and std
        ]
    )

    # Data augmentations for training
    train_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(IMG_SHAPE, padding=4),
            transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(**stats),
        ]
    )
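    # Note on ordering: the inputs here are PIL images, so ToTensor() has to
    # come after the image-level augmentations, and Normalize() must come
    # last because it only accepts tensors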

    # Build the train/test dataloaders
    trainloader, testloader = pokemon_dataset.get_dataloaders(
        train_transform=train_transform,
        test_transform=test_transform,
        train_batch_size=args.train_batch_size,
        test_batch_size=args.test_batch_size,
    )

    # Plot a few examples from each split (stats presumably let
    # plot_examples undo the normalization for display)
    pokemon_dataset.plot_examples(testloader, stats=stats)
    pokemon_dataset.plot_examples(trainloader, stats=stats)

    # Initialize the chosen architecture (e.g., fine-tune a ResNet)
    model = initialize_model(
        args.model,
        NUM_CLASSES,
        feature_extract=args.feature_extract,
        use_pretrained=args.use_pretrained,
    )
    # Print the model we just instantiated
    print(model)
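    # initialize_model is presumed to follow the PyTorch finetuning-tutorial
    # pattern: load the torchvision backbone (pretrained if requested),
    # freeze its parameters when feature_extract=True, and replace the final
    # classifier layer with a fresh one sized for NUM_CLASSES outputs
    # (an assumption about utils.train, inferred from the argument names)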

    # Loss and optimizer; when feature extracting, only the parameters left
    # trainable (requires_grad=True) need to be handed to the optimizer
    criterion = nn.CrossEntropyLoss()
    params_to_update = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params_to_update, lr=args.lr)

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
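    # Move the model to the target device up front (harmless if
    # train_and_evaluate also does this internally, which is an assumption
    # about utils.train since it receives `device` as an argument)
    model = model.to(device)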

    with mlflow.start_run(
        experiment_id=args.experiment_id,
        run_name=f"{args.model}_{'feature_extracting' if args.feature_extract else 'finetuning'}"
        f"_{'pretrained' if args.use_pretrained else 'not_pretrained'}"
        f"_{args.indices_file}_{random.randint(0, 1000)}",
    ) as run:
        # Log the hyperparameters
        mlflow.log_param("epochs", args.epochs)
        mlflow.log_param("lr", args.lr)
        mlflow.log_param("train_batch_size", args.train_batch_size)
        mlflow.log_param("test_batch_size", args.test_batch_size)
        mlflow.log_param("model", args.model)
        mlflow.log_param("feature_extract", args.feature_extract)
        mlflow.log_param("use_pretrained", args.use_pretrained)

        # Train and evaluate
        history = train_and_evaluate(
            model=model,
            trainloader=trainloader,
            testloader=testloader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epochs=args.epochs,
            use_mlflow=True,
        )

        # Save the trained weights and attach them to the run
        torch.save(model.state_dict(), f"pokemon_{args.model}.pth")
        mlflow.log_artifact(f"pokemon_{args.model}.pth")
        # No explicit mlflow.end_run() is needed: the context manager closes
        # the run when the with-block exits
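
    # To reload the trained weights later (a sketch; the architecture must be
    # rebuilt with the same initialize_model arguments so the state dict fits):
    #   model = initialize_model(args.model, NUM_CLASSES,
    #                            feature_extract=args.feature_extract,
    #                            use_pretrained=False)
    #   model.load_state_dict(torch.load(f"pokemon_{args.model}.pth", map_location="cpu"))
    #   model.eval()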