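"""Train an image classifier on the Pokemon dataset and track the run with MLflow.

Example invocation (script name assumed for illustration):

    python train.py --model resnet --epochs 20 --lr 0.001

See ``parser_args`` below for the full list of options.
"""
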
import argparse
import random

import mlflow
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

from utils.data import PokemonDataModule
from utils.train import initialize_model, train_and_evaluate

IMG_SHAPE = (224, 224)


def parser_args():
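    """Parse the command-line arguments for a training run."""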
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./pokemonclassification/PokemonData",
        help="Path to the data directory",
    )
    parser.add_argument(
        "--indices_file",
        type=str,
        default="indices_60_32.pkl",
        help="Path to the indices file",
    )
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batch_size", type=int, default=128, help="Training batch size"
    )
    parser.add_argument(
        "--test_batch_size", type=int, default=512, help="Test batch size"
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"],
        default="resnet",
        help="Model architecture to use",
    )
    # argparse's type=bool treats any non-empty string as True, so use
    # BooleanOptionalAction (Python 3.9+) for proper --flag / --no-flag parsing.
    parser.add_argument(
        "--feature_extract",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to freeze the backbone (feature extraction) instead of fine-tuning it",
    )
    parser.add_argument(
        "--use_pretrained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to initialize the model with pretrained weights",
    )
    parser.add_argument(
        "--experiment_id",
        type=int,
        default=0,
        help="MLflow experiment ID under which to log the results",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parser_args()

    pokemon_dataset = PokemonDataModule(args.data_dir)
    NUM_CLASSES = len(pokemon_dataset.class_names)

    print(f"Number of classes: {NUM_CLASSES}")

    if "indices_60_32.pkl" in args.indices_file:
        # Channel statistics precomputed for the indices_60_32.pkl split;
        # for any other split they are computed by prepare_data below.
        channel_means = torch.tensor([0.6062, 0.5889, 0.5550])
        channel_stds = torch.tensor([0.3284, 0.3115, 0.3266])
        stats = {"mean": channel_means, "std": channel_stds}
        _ = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=False
        )
    else:
        stats = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=True
        )

    print(f"Train dataset size: {len(pokemon_dataset.train_dataset)}")
    print(f"Test dataset size: {len(pokemon_dataset.test_dataset)}")
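
    # Deterministic preprocessing for evaluation: resize, convert to tensor, normalize.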
    test_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.ToTensor(),
            transforms.Normalize(**stats),
        ]
    )
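
    # Training augmentation: geometric and photometric jitter applied before normalization.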
    train_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(IMG_SHAPE, padding=4),
            transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(**stats),
        ]
    )

    trainloader, testloader = pokemon_dataset.get_dataloaders(
        train_transform=train_transform,
        test_transform=test_transform,
        train_batch_size=args.train_batch_size,
        test_batch_size=args.test_batch_size,
    )
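
    # Visual sanity check of a few batches from each dataloader.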
    pokemon_dataset.plot_examples(testloader, stats=stats)
    pokemon_dataset.plot_examples(trainloader, stats=stats)

    model = initialize_model(
        args.model,
        NUM_CLASSES,
        feature_extract=args.feature_extract,
        use_pretrained=args.use_pretrained,
    )
    print(model)
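
    # Cross-entropy loss over the Pokemon classes. All parameters are handed to Adam;
    # when the backbone is frozen (feature extraction), only parameters that still
    # receive gradients are actually updated.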
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
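
    # Track the whole run in MLflow: hyperparameters, training, and the saved weights.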
    with mlflow.start_run(
        experiment_id=args.experiment_id,
        run_name=f"{args.model}_{'finetuning' if not args.feature_extract else 'feature_extracting'}"
        f"_{'pretrained' if args.use_pretrained else 'not_pretrained'}"
        f"_{args.indices_file}_{random.randint(0, 1000)}",
    ) as run:
        mlflow.log_param("epochs", args.epochs)
        mlflow.log_param("lr", args.lr)
        mlflow.log_param("train_batch_size", args.train_batch_size)
        mlflow.log_param("test_batch_size", args.test_batch_size)
        mlflow.log_param("model", args.model)
        mlflow.log_param("feature_extract", args.feature_extract)
        mlflow.log_param("use_pretrained", args.use_pretrained)
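
        # train_and_evaluate is expected to log its metrics to the active run
        # when use_mlflow=True.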
        history = train_and_evaluate(
            model=model,
            trainloader=trainloader,
            testloader=testloader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epochs=args.epochs,
            use_mlflow=True,
        )

        # Save the trained weights and attach them to the same MLflow run.
        torch.save(model.state_dict(), f"pokemon_{args.model}.pth")
        mlflow.log_artifact(f"pokemon_{args.model}.pth")