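"""Training script for CNN image classifiers on the Overhead-MNIST dataset.

Supports multiple model variants, SGD / AdamW optimizers, an optional
polynomial learning-rate schedule, and metric logging with Weights & Biases.
"""
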
import argparse
import logging
import os
import sys
import time
from typing import List, Tuple, Union

import torch
import wandb
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from dataset import split_dataset, get_dataloaders_for_training
from models import SimpleCNN, SimpleResNet, ComplexResNet, ComplexResNetV2


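# Polynomial decay of the learning rate:
# lr = base_lr * (1 - epoch / max_epochs) ** power, floored at min_lr.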
class PolynomialLR(_LRScheduler):
def __init__(
self,
optimizer: Union[SGD, AdamW],
max_epochs: int,
power: float = 0.9,
last_epoch: int = -1,
min_lr: float = 1e-6,
):
self.power = power
self.max_epochs = max_epochs
self.min_lr = min_lr # avoid zero lr
        super().__init__(optimizer, last_epoch)
    def get_lr(self) -> List[float]:
        return [
            max(
                base_lr * (1 - self.last_epoch / self.max_epochs) ** self.power,
                self.min_lr,
            )
            for base_lr in self.base_lrs
        ]


def train(
model: Union[SimpleCNN, SimpleResNet, ComplexResNet, ComplexResNetV2],
optimizer: Union[SGD, AdamW],
criterion: CrossEntropyLoss,
train_loader: DataLoader,
device: torch.device,
) -> Tuple[float, float]:
"""
---------
Arguments
---------
model: object
an object of type torch model
optimizer: object
an object of type torch Optimizer
criterion: object
an object of type torch criterion function
train_loader: object
an object of type torch dataloader
device: object
an object of type torch device
-------
Returns
-------
(train_loss, train_acc) : tuple
a tuple of training loss and training accuracy
"""
model.to(device)
model.train()
train_running_loss = 0.0
train_running_correct = 0
num_train_samples = len(train_loader.dataset)
num_train_batches = len(train_loader)
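    # accumulate the summed batch losses and the count of correct predictions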
for data, label in train_loader:
data = data.to(device, dtype=torch.float)
label = label.to(device, dtype=torch.long)
optimizer.zero_grad()
logits = model(data)
loss = criterion(logits, label)
train_running_loss += loss.item()
pred_label = torch.argmax(logits, dim=1)
train_running_correct += (pred_label == label).sum().item()
loss.backward()
optimizer.step()
    # mean loss per batch; accuracy as a percentage over all training samples
    train_loss = train_running_loss / num_train_batches
    train_acc = 100.0 * train_running_correct / num_train_samples
    return train_loss, train_acc


def validate(
model: Union[SimpleCNN, SimpleResNet, ComplexResNet, ComplexResNetV2],
criterion: CrossEntropyLoss,
validation_loader: DataLoader,
device: torch.device,
) -> Tuple[float, float]:
"""
---------
Arguments
---------
model: object
an object of type torch model
criterion: object
an object of type torch criterion function
validation_loader: object
an object of type torch dataloader
device: object
an object of type torch device
-------
Returns
-------
(validation_loss, validation_acc) : tuple
a tuple of validation loss and validation accuracy
"""
model.to(device)
model.eval()
validation_running_loss = 0.0
validation_running_correct = 0
num_validation_samples = len(validation_loader.dataset)
num_validation_batches = len(validation_loader)
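    # disable gradient tracking; no parameter updates happen during validation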
with torch.no_grad():
for data, label in validation_loader:
data = data.to(device, dtype=torch.float)
label = label.to(device, dtype=torch.long)
logits = model(data)
loss = criterion(logits, label)
validation_running_loss += loss.item()
pred_label = torch.argmax(logits, dim=1)
validation_running_correct += (pred_label == label).sum().item()
validation_loss = validation_running_loss / num_validation_batches
validation_acc = 100.0 * validation_running_correct / num_validation_samples
return validation_loss, validation_acc
def train_classifier(ARGS: argparse.Namespace) -> None:
logging.basicConfig(level=logging.INFO)
wandb.login()
if torch.cuda.is_available():
device = torch.device("cuda")
else:
logging.info("CUDA device not found, so exiting....")
sys.exit(0)
logging.info("Training a CNN model for the Overhead MNIST dataset")
train_x = []
train_y = []
dir_train = os.path.join(ARGS.dir_dataset, "train")
list_sub_dirs = sorted(os.listdir(dir_train))
num_classes = len(list_sub_dirs)
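    # each class has its own subdirectory under train/; the index of the sorted
    # subdirectory name serves as the integer class label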
for sub_dir_idx in range(num_classes):
temp_train_x = os.listdir(os.path.join(dir_train, list_sub_dirs[sub_dir_idx]))
temp_train_x = [
os.path.join(list_sub_dirs[sub_dir_idx], f) for f in temp_train_x
]
temp_train_y = [sub_dir_idx] * len(temp_train_x)
train_x = train_x + temp_train_x
train_y = train_y + temp_train_y
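    # hold out a validation split from the training images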
(
train_x,
validation_x,
train_y,
validation_y,
) = split_dataset(train_x, train_y)
num_train_samples = len(train_x)
num_validation_samples = len(validation_x)
train_loader, validation_loader = get_dataloaders_for_training(
train_x,
train_y,
validation_x,
validation_y,
dir_images=dir_train,
batch_size=ARGS.batch_size,
)
dir_model = os.path.join(ARGS.dir_model, ARGS.model_type)
if not os.path.isdir(dir_model):
logging.info(f"Creating directory: {dir_model}")
os.makedirs(dir_model)
logging.info(
f"Num train samples: {num_train_samples}, num validation samples: {num_validation_samples}"
)
logging.info(f"Num classes: {num_classes}, model_type: {ARGS.model_type}")
if ARGS.model_type == "simple_cnn":
model = SimpleCNN(num_classes=num_classes)
elif ARGS.model_type == "simple_resnet":
model = SimpleResNet(num_classes=num_classes)
elif ARGS.model_type == "medium_simple_resnet":
model = SimpleResNet(
list_num_res_units_per_block=[4, 4], num_classes=num_classes
)
elif ARGS.model_type == "deep_simple_resnet":
model = SimpleResNet(
list_num_res_units_per_block=[6, 6], num_classes=num_classes
)
elif ARGS.model_type == "complex_resnet":
model = ComplexResNet(
list_num_res_units_per_block=[4, 4, 4], num_classes=num_classes
)
elif ARGS.model_type == "complex_resnet_v2":
model = ComplexResNetV2(
list_num_res_units_per_block=[4, 4, 4], num_classes=num_classes
)
    else:
        logging.info(f"Unidentified option for arg (model_type): {ARGS.model_type}")
        sys.exit(1)
model.to(device)
if ARGS.optimizer_type == "sgd":
        optimizer = SGD(
model.parameters(),
lr=ARGS.learning_rate,
weight_decay=ARGS.weight_decay,
momentum=0.9,
)
else:
        optimizer = AdamW(
model.parameters(), lr=ARGS.learning_rate, weight_decay=ARGS.weight_decay
)
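    # lr_scheduler is only created (and later stepped) when the poly option is chosen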
if ARGS.lr_scheduler_type == "poly":
lr_scheduler = PolynomialLR(
optimizer,
ARGS.num_epochs + 1,
power=0.75,
)
criterion = CrossEntropyLoss()
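    # hyperparameters and run metadata passed to wandb.init for experiment tracking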
config = {
"dataset": "Overhead-MNIST",
"architecture": "CNN",
"model_type": ARGS.model_type,
"optimizer": ARGS.optimizer_type,
"lr_scheduler": ARGS.lr_scheduler_type,
"learning_rate": ARGS.learning_rate,
"num_epochs": ARGS.num_epochs,
"batch_size": ARGS.batch_size,
"weight_decay": ARGS.weight_decay,
}
best_validation_acc = 0
logging.info(
f"Training the Overhead MNIST image classification model started, model_type: {ARGS.model_type}"
)
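    # run the training loop inside a wandb run so per-epoch metrics are logged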
with wandb.init(project="overhead-mnist-model", config=config):
for epoch in range(1, ARGS.num_epochs + 1):
time_start = time.time()
train_loss, train_acc = train(
model, optimizer, criterion, train_loader, device
)
validation_loss, validation_acc = validate(
model, criterion, validation_loader, device
)
time_end = time.time()
logging.info(
f"Epoch: {epoch}/{ARGS.num_epochs}, time: {time_end-time_start:.4f} sec."
)
logging.info(
f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}"
)
logging.info(
f"Validation Loss: {validation_loss:.4f}, Validation Accuracy: {validation_acc:.4f}\n"
)
wandb.log(
{
"train_loss": train_loss,
"train_acc": train_acc,
"validation_loss": validation_loss,
"validation_acc": validation_acc,
},
step=epoch,
)
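            # checkpoint whenever validation accuracy matches or beats the best so far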
if validation_acc >= best_validation_acc:
best_validation_acc = validation_acc
torch.save(
model.state_dict(), os.path.join(dir_model, f"{ARGS.model_type}.pt")
)
wandb.save(os.path.join(dir_model, f"{ARGS.model_type}.pt"))
if ARGS.lr_scheduler_type == "poly":
lr_scheduler.step()
logging.info("Training the Overhead MNIST image classification model complete!!!!")
return
def main() -> None:
learning_rate = 1e-3
weight_decay = 5e-6
batch_size = 64
num_epochs = 100
model_type = "simple_cnn"
dir_dataset = "/home/abhishek/Desktop/datasets/overhead_mnist/version2"
dir_model = "trained_models"
lr_scheduler_type = "poly"
optimizer_type = "adam"
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--learning_rate",
default=learning_rate,
type=float,
help="learning rate to use for training",
)
parser.add_argument(
"--weight_decay",
default=weight_decay,
type=float,
help="weight decay to use for training",
)
parser.add_argument(
"--batch_size",
default=batch_size,
type=int,
help="batch size to use for training",
)
parser.add_argument(
"--num_epochs",
default=num_epochs,
type=int,
help="num epochs to train the model",
)
parser.add_argument(
"--dir_dataset",
default=dir_dataset,
type=str,
help="full directory path to dataset containing images",
)
parser.add_argument(
"--dir_model",
default=dir_model,
type=str,
help="full directory path where model needs to be saved",
)
parser.add_argument(
"--model_type",
default=model_type,
type=str,
choices=[
"simple_cnn",
"simple_resnet",
"medium_simple_resnet",
"deep_simple_resnet",
"complex_resnet",
"complex_resnet_v2",
],
help="model type to be trained",
)
parser.add_argument(
"--lr_scheduler_type",
default=lr_scheduler_type,
type=str,
choices=["none", "poly"],
help="learning rate scheduler to be used for training",
)
parser.add_argument(
"--optimizer_type",
default=optimizer_type,
type=str,
choices=["adam", "sgd"],
help="optimizer to be used for training",
)
    ARGS, _ = parser.parse_known_args()
    train_classifier(ARGS)


if __name__ == "__main__":
main()