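"""Train and evaluate a Pokemon image classifier.

Loads the dataset through PokemonDataModule, builds train/test transforms,
initializes one of several CNN backbones (resnet, alexnet, vgg, squeezenet,
or densenet), and logs hyperparameters and the trained weights to MLflow.
"""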
import argparse
import random

import mlflow
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

from utils.data import PokemonDataModule
from utils.train import initialize_model, train_and_evaluate

# The input image shape that the models expect (standard 224x224 input)
IMG_SHAPE = (224, 224)

def parser_args():
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./pokemonclassification/PokemonData",
        help="Path to the data directory",
    )
    parser.add_argument(
        "--indices_file",
        type=str,
        default="indices_60_32.pkl",
        help="Path to the indices file",
    )
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batch_size", type=int, default=128, help="Train batch size"
    )
    parser.add_argument(
        "--test_batch_size", type=int, default=512, help="Test batch size"
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"],
        default="resnet",
        help="Model to be used",
    )
    # NOTE: `type=bool` would parse any non-empty string (even "False") as
    # True, so the boolean flags use BooleanOptionalAction (Python 3.9+),
    # which also generates matching --no-* switches.
    parser.add_argument(
        "--feature_extract",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="whether to freeze the backbone or not",
    )
    parser.add_argument(
        "--use_pretrained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="whether to use a pretrained model or not",
    )
    parser.add_argument(
        "--experiment_id",
        type=int,
        default=0,
        help="Experiment ID to log the results",
    )
    return parser.parse_args()
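

# Example invocation (a sketch; assumes this script is saved as main.py):
#   python main.py --model resnet --epochs 20 --lr 0.001 --no-feature_extract
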
if __name__ == "__main__":
    args = parser_args()
    pokemon_dataset = PokemonDataModule(args.data_dir)
    # Get the number of classes from the dataset's class names
    NUM_CLASSES = len(pokemon_dataset.class_names)
    print(f"Number of classes: {NUM_CLASSES}")
    # The precomputed channel means and stds are only valid for the default
    # split ('indices_60_32.pkl'); any other indices file requires recomputing
    # the statistics from the training data.
    if "indices_60_32.pkl" in args.indices_file:
        channel_means = torch.tensor([0.6062, 0.5889, 0.5550])
        channel_stds = torch.tensor([0.3284, 0.3115, 0.3266])
        stats = {"mean": channel_means, "std": channel_stds}
        _ = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=False
        )
    else:
        stats = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=True
        )
    print(f"Train dataset size: {len(pokemon_dataset.train_dataset)}")
    print(f"Test dataset size: {len(pokemon_dataset.test_dataset)}")
    # Transformations of data for testing
    test_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.ToTensor(),  # Convert PIL images to tensors
            transforms.Normalize(**stats),  # Normalize images using mean and std
        ]
    )
    # Data augmentations for training
    train_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(IMG_SHAPE, padding=4),
            transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(**stats),
        ]
    )
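    # Note: the stochastic augmentations above apply only to training batches;
    # the test pipeline stays deterministic (resize + normalize) so evaluation
    # is reproducible.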
    # Get dataloaders
    trainloader, testloader = pokemon_dataset.get_dataloaders(
        train_transform=train_transform,
        test_transform=test_transform,
        train_batch_size=args.train_batch_size,
        test_batch_size=args.test_batch_size,
    )
    # Visualize a few example batches from each split
    pokemon_dataset.plot_examples(testloader, stats=stats)
    pokemon_dataset.plot_examples(trainloader, stats=stats)
    # Instantiate the selected model (fine-tuning or feature extraction)
    model = initialize_model(
        args.model,
        NUM_CLASSES,
        feature_extract=args.feature_extract,
        use_pretrained=args.use_pretrained,
    )
    # Print the model we just instantiated
    print(model)
    # Device configuration: move the model before creating the optimizer so
    # the optimizer tracks the parameters on the right device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    with mlflow.start_run(
        experiment_id=args.experiment_id,
        run_name=f"{args.model}_{'finetuning' if not args.feature_extract else 'feature_extracting'}"
        f"_{'pretrained' if args.use_pretrained else 'not_pretrained'}"
        f"_{args.indices_file}_{random.randint(0, 1000)}",
    ) as run:
        mlflow.log_param("epochs", args.epochs)
        mlflow.log_param("lr", args.lr)
        mlflow.log_param("train_batch_size", args.train_batch_size)
        mlflow.log_param("test_batch_size", args.test_batch_size)
        mlflow.log_param("model", args.model)
        mlflow.log_param("feature_extract", args.feature_extract)
        mlflow.log_param("use_pretrained", args.use_pretrained)
        # Train and evaluate
        history = train_and_evaluate(
            model=model,
            trainloader=trainloader,
            testloader=testloader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epochs=args.epochs,
            use_mlflow=True,
        )
        # Save the model weights and attach them to the MLflow run
        torch.save(model.state_dict(), f"pokemon_{args.model}.pth")
        mlflow.log_artifact(f"pokemon_{args.model}.pth")
        # No explicit mlflow.end_run() needed: the context manager closes the run
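
# To reuse the saved weights later (a minimal sketch; assumes the same
# initialize_model signature as above):
#   model = initialize_model(args.model, NUM_CLASSES,
#                            feature_extract=args.feature_extract,
#                            use_pretrained=False)
#   model.load_state_dict(torch.load(f"pokemon_{args.model}.pth", map_location="cpu"))
#   model.eval()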