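"""Training script for Overhead MNIST image classifiers.

Builds one of several CNN / ResNet variants (see models.py), trains it with
SGD or AdamW, optionally decays the learning rate with a polynomial schedule,
and logs metrics and the best checkpoint to Weights & Biases.

Example invocation (script filename and dataset path are placeholders):
    python train.py --model_type simple_cnn --num_epochs 100 \
        --dir_dataset /path/to/overhead_mnist
"""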
import os
import sys
import time
import torch
import wandb
import logging
import argparse

from torch.optim import SGD, AdamW
from typing import Union, List, Tuple
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import _LRScheduler

from dataset import split_dataset, get_dataloaders_for_training
from models import SimpleCNN, SimpleResNet, ComplexResNet, ComplexResNetV2

class PolynomialLR(_LRScheduler):
    """Polynomial learning rate decay, clamped at a minimum value."""

    def __init__(
        self,
        optimizer: Union[SGD, AdamW],
        max_epochs: int,
        power: float = 0.9,
        last_epoch: int = -1,
        min_lr: float = 1e-6,
    ):
        self.power = power
        self.max_epochs = max_epochs
        self.min_lr = min_lr  # floor that keeps the lr from decaying to zero
        super().__init__(optimizer, last_epoch)
    def get_lr(self) -> List[float]:
        # lr = base_lr * (1 - last_epoch / max_epochs) ** power, floored at min_lr
        return [
            max(
                base_lr * (1 - self.last_epoch / self.max_epochs) ** self.power,
                self.min_lr,
            )
            for base_lr in self.base_lrs
        ]
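
# Illustrative schedule values with the constructor defaults above
# (base_lr=1e-3, max_epochs=100, power=0.9):
#   epoch 0 -> 1.00e-3, epoch 50 -> ~5.36e-4, epoch 99 -> ~1.58e-5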

def train(
    model: Union[SimpleCNN, SimpleResNet, ComplexResNet, ComplexResNetV2],
    optimizer: Union[SGD, AdamW],
    criterion: CrossEntropyLoss,
    train_loader: DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Runs one epoch of training.

    ---------
    Arguments
    ---------
    model: torch.nn.Module
        the model to be trained
    optimizer: torch.optim.Optimizer
        the optimizer used to update the model parameters
    criterion: torch.nn.Module
        the loss function
    train_loader: torch.utils.data.DataLoader
        dataloader for the training split
    device: torch.device
        device on which to run the computation

    -------
    Returns
    -------
    (train_loss, train_acc): tuple
        mean training loss per batch and training accuracy in percent
    """
    model.to(device)
    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    num_train_samples = len(train_loader.dataset)
    num_train_batches = len(train_loader)
    for data, label in train_loader:
        data = data.to(device, dtype=torch.float)
        label = label.to(device, dtype=torch.long)
        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, label)
        train_running_loss += loss.item()
        pred_label = torch.argmax(logits, dim=1)
        train_running_correct += (pred_label == label).sum().item()
        loss.backward()
        optimizer.step()
    train_loss = train_running_loss / num_train_batches
    train_acc = 100.0 * train_running_correct / num_train_samples
    return train_loss, train_acc

def validate(
    model: Union[SimpleCNN, SimpleResNet, ComplexResNet, ComplexResNetV2],
    criterion: CrossEntropyLoss,
    validation_loader: DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Evaluates the model on the validation split.

    ---------
    Arguments
    ---------
    model: torch.nn.Module
        the model to be evaluated
    criterion: torch.nn.Module
        the loss function
    validation_loader: torch.utils.data.DataLoader
        dataloader for the validation split
    device: torch.device
        device on which to run the computation

    -------
    Returns
    -------
    (validation_loss, validation_acc): tuple
        mean validation loss per batch and validation accuracy in percent
    """
    model.to(device)
    model.eval()
    validation_running_loss = 0.0
    validation_running_correct = 0
    num_validation_samples = len(validation_loader.dataset)
    num_validation_batches = len(validation_loader)
    with torch.no_grad():
        for data, label in validation_loader:
            data = data.to(device, dtype=torch.float)
            label = label.to(device, dtype=torch.long)
            logits = model(data)
            loss = criterion(logits, label)
            validation_running_loss += loss.item()
            pred_label = torch.argmax(logits, dim=1)
            validation_running_correct += (pred_label == label).sum().item()
    validation_loss = validation_running_loss / num_validation_batches
    validation_acc = 100.0 * validation_running_correct / num_validation_samples
    return validation_loss, validation_acc

def train_classifier(ARGS: argparse.Namespace) -> None:
    logging.basicConfig(level=logging.INFO)
    wandb.login()
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        logging.error("CUDA device not found, exiting.")
        sys.exit(1)
    logging.info("Training a CNN model for the Overhead MNIST dataset")
    train_x = []
    train_y = []
    dir_train = os.path.join(ARGS.dir_dataset, "train")
    list_sub_dirs = sorted(os.listdir(dir_train))
    num_classes = len(list_sub_dirs)
    for sub_dir_idx in range(num_classes):
        temp_train_x = os.listdir(os.path.join(dir_train, list_sub_dirs[sub_dir_idx]))
        temp_train_x = [
            os.path.join(list_sub_dirs[sub_dir_idx], f) for f in temp_train_x
        ]
        temp_train_y = [sub_dir_idx] * len(temp_train_x)
        train_x = train_x + temp_train_x
        train_y = train_y + temp_train_y
    (
        train_x,
        validation_x,
        train_y,
        validation_y,
    ) = split_dataset(train_x, train_y)
    num_train_samples = len(train_x)
    num_validation_samples = len(validation_x)
    train_loader, validation_loader = get_dataloaders_for_training(
        train_x,
        train_y,
        validation_x,
        validation_y,
        dir_images=dir_train,
        batch_size=ARGS.batch_size,
    )
    dir_model = os.path.join(ARGS.dir_model, ARGS.model_type)
    if not os.path.isdir(dir_model):
        logging.info(f"Creating directory: {dir_model}")
        os.makedirs(dir_model)
    logging.info(
        f"Num train samples: {num_train_samples}, num validation samples: {num_validation_samples}"
    )
    logging.info(f"Num classes: {num_classes}, model_type: {ARGS.model_type}")
    if ARGS.model_type == "simple_cnn":
        model = SimpleCNN(num_classes=num_classes)
    elif ARGS.model_type == "simple_resnet":
        model = SimpleResNet(num_classes=num_classes)
    elif ARGS.model_type == "medium_simple_resnet":
        model = SimpleResNet(
            list_num_res_units_per_block=[4, 4], num_classes=num_classes
        )
    elif ARGS.model_type == "deep_simple_resnet":
        model = SimpleResNet(
            list_num_res_units_per_block=[6, 6], num_classes=num_classes
        )
    elif ARGS.model_type == "complex_resnet":
        model = ComplexResNet(
            list_num_res_units_per_block=[4, 4, 4], num_classes=num_classes
        )
    elif ARGS.model_type == "complex_resnet_v2":
        model = ComplexResNetV2(
            list_num_res_units_per_block=[4, 4, 4], num_classes=num_classes
        )
    else:
        # fail fast instead of falling through with an undefined model
        logging.error(f"Unidentified option for arg (model_type): {ARGS.model_type}")
        sys.exit(1)
    model.to(device)
    if ARGS.optimizer_type == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=ARGS.learning_rate,
            weight_decay=ARGS.weight_decay,
            momentum=0.9,
        )
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=ARGS.learning_rate, weight_decay=ARGS.weight_decay
        )
    if ARGS.lr_scheduler_type == "poly":
        # num_epochs + 1 keeps the decay factor above zero at the final epoch
        lr_scheduler = PolynomialLR(
            optimizer,
            ARGS.num_epochs + 1,
            power=0.75,
        )
    criterion = CrossEntropyLoss()
    config = {
        "dataset": "Overhead-MNIST",
        "architecture": "CNN",
        "model_type": ARGS.model_type,
        "optimizer": ARGS.optimizer_type,
        "lr_scheduler": ARGS.lr_scheduler_type,
        "learning_rate": ARGS.learning_rate,
        "num_epochs": ARGS.num_epochs,
        "batch_size": ARGS.batch_size,
        "weight_decay": ARGS.weight_decay,
    }
    best_validation_acc = 0
    logging.info(
        f"Training the Overhead MNIST image classification model started, model_type: {ARGS.model_type}"
    )
    with wandb.init(project="overhead-mnist-model", config=config):
        for epoch in range(1, ARGS.num_epochs + 1):
            time_start = time.time()
            train_loss, train_acc = train(
                model, optimizer, criterion, train_loader, device
            )
            validation_loss, validation_acc = validate(
                model, criterion, validation_loader, device
            )
            time_end = time.time()
            logging.info(
                f"Epoch: {epoch}/{ARGS.num_epochs}, time: {time_end - time_start:.4f} sec."
            )
            logging.info(
                f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}"
            )
            logging.info(
                f"Validation Loss: {validation_loss:.4f}, Validation Accuracy: {validation_acc:.4f}\n"
            )
            wandb.log(
                {
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "validation_loss": validation_loss,
                    "validation_acc": validation_acc,
                },
                step=epoch,
            )
            # checkpoint whenever validation accuracy matches or beats the best so far
            if validation_acc >= best_validation_acc:
                best_validation_acc = validation_acc
                torch.save(
                    model.state_dict(), os.path.join(dir_model, f"{ARGS.model_type}.pt")
                )
                wandb.save(os.path.join(dir_model, f"{ARGS.model_type}.pt"))
            # the scheduler is stepped once per epoch
            if ARGS.lr_scheduler_type == "poly":
                lr_scheduler.step()
    logging.info("Training the Overhead MNIST image classification model complete!")
    return

def main() -> None:
    learning_rate = 1e-3
    weight_decay = 5e-6
    batch_size = 64
    num_epochs = 100
    model_type = "simple_cnn"
    dir_dataset = "/home/abhishek/Desktop/datasets/overhead_mnist/version2"
    dir_model = "trained_models"
    lr_scheduler_type = "poly"
    optimizer_type = "adam"

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--learning_rate",
        default=learning_rate,
        type=float,
        help="learning rate to use for training",
    )
    parser.add_argument(
        "--weight_decay",
        default=weight_decay,
        type=float,
        help="weight decay to use for training",
    )
    parser.add_argument(
        "--batch_size",
        default=batch_size,
        type=int,
        help="batch size to use for training",
    )
    parser.add_argument(
        "--num_epochs",
        default=num_epochs,
        type=int,
        help="num epochs to train the model",
    )
    parser.add_argument(
        "--dir_dataset",
        default=dir_dataset,
        type=str,
        help="full directory path to dataset containing images",
    )
    parser.add_argument(
        "--dir_model",
        default=dir_model,
        type=str,
        help="full directory path where model needs to be saved",
    )
    parser.add_argument(
        "--model_type",
        default=model_type,
        type=str,
        choices=[
            "simple_cnn",
            "simple_resnet",
            "medium_simple_resnet",
            "deep_simple_resnet",
            "complex_resnet",
            "complex_resnet_v2",
        ],
        help="model type to be trained",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        default=lr_scheduler_type,
        type=str,
        choices=["none", "poly"],
        help="learning rate scheduler to be used for training",
    )
    parser.add_argument(
        "--optimizer_type",
        default=optimizer_type,
        type=str,
        choices=["adam", "sgd"],
        help="optimizer to be used for training",
    )
    ARGS, _ = parser.parse_known_args()
    train_classifier(ARGS)
    return


if __name__ == "__main__":
    main()