import time
import logging
from argparse import ArgumentParser
from copy import deepcopy
from typing import Dict

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm
from prettytable import PrettyTable

import config as CFG
from data import transforms
from model import HNet

# Check whether the models folder exists; create it if it doesn't
(CFG.BASE_PATH / "models").mkdir(exist_ok=True)

# Set up logger
logging.basicConfig(
    filename="train.log",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    filemode="a",
)

best_acc = 0.0


def run_one_epoch(
    epoch: int,
    ds_sizes: Dict[str, int],
    dataloaders: Dict[str, DataLoader],
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scheduler: lr_scheduler._LRScheduler,
):
    """
    Run one complete train-val loop.

    Parameters
    ----------
    epoch: Current epoch index (used for logging)
    ds_sizes: Dictionary containing dataset sizes
    dataloaders: Dictionary containing dataloaders
    model: The model
    optimizer: The optimizer
    criterion: The loss function
    scheduler: The learning-rate scheduler

    Returns
    -------
    metrics: Dictionary containing train (loss/accuracy) and validation (loss/accuracy)
    """
    global best_acc
    metrics = {}

    for phase in ["train", "val"]:
        logging.info(f"{phase.upper()} phase")
        if phase == "train":
            model.train()
        else:
            model.eval()

        avg_loss = 0.0
        running_corrects = 0

        for batch_idx, (images, labels) in enumerate(
            tqdm(dataloaders[phase], total=len(dataloaders[phase]))
        ):
            images = images.to(CFG.DEVICE)
            labels = labels.to(CFG.DEVICE)

            # Zero the gradients
            optimizer.zero_grad()

            # Track gradient history only in the training phase
            with torch.set_grad_enabled(phase == "train"):
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == "train":
                    loss.backward()
                    optimizer.step()

            avg_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels)

            if batch_idx % CFG.INTERVAL == 0:
                corrects = torch.sum(preds == labels)
                # Use the actual batch size: the last batch may be smaller
                # than CFG.BATCH_SIZE
                logging.info(
                    f"Epoch {epoch} - {phase.upper()} - Batch {batch_idx} - "
                    f"Loss = {round(loss.item(), 3)} | "
                    f"Accuracy = {round(100 * corrects.item() / labels.size(0), 3)}%"
                )

        epoch_loss = avg_loss / ds_sizes[phase]
        epoch_acc = running_corrects.double() / ds_sizes[phase]

        # Step the scheduler once per epoch
        # (NOTE: CyclicLR is designed to be stepped after every batch;
        # stepping it per epoch stretches each LR cycle across many epochs)
        if phase == "train":
            scheduler.step()

        # Save the best model weights seen so far
        if phase == "val" and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = deepcopy(model.state_dict())
            torch.save(best_model_wts, CFG.BEST_MODEL_PATH)

        # Metrics tracking
        if phase == "train":
            metrics["train_loss"] = round(epoch_loss, 3)
            metrics["train_acc"] = round(100 * epoch_acc.item(), 3)
        else:
            metrics["val_loss"] = round(epoch_loss, 3)
            metrics["val_acc"] = round(100 * epoch_acc.item(), 3)

    return metrics


def train(dataloaders, ds_sizes, model, optimizer, criterion, scheduler):
    for epoch in range(CFG.EPOCHS):
        start = time.time()
        metrics = run_one_epoch(
            epoch=epoch,
            ds_sizes=ds_sizes,
            dataloaders=dataloaders,
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            scheduler=scheduler,
        )
        end = time.time() - start
        print(f"Epoch completed in: {round(end / 60, 3)} mins")

        table.add_row(
            row=[
                epoch + 1,
                metrics["train_loss"],
                metrics["train_acc"],
                metrics["val_loss"],
                metrics["val_acc"],
            ]
        )
        print(table)

    # Write results to file
    with open("results.txt", "w") as f:
        results = table.get_string()
        f.write(results)


if __name__ == "__main__":
    parser = ArgumentParser(description="Train model for Hindi Character Recognition")
    parser.add_argument(
        "--epochs", type=int, help="number of epochs", default=CFG.EPOCHS
    )
    parser.add_argument("--lr", type=float, help="learning rate", default=CFG.LR)
    parser.add_argument(
        "--model_type",
        type=str,
        help="Type of model (vyanjan/digit)",
        default="vyanjan",
    )
    args = parser.parse_args()

    if args.model_type == "digit":
        model = HNet(num_classes=10)
        logging.info("Initialized Digit model")
        TRAIN_PATH = CFG.TRAIN_DIGIT_PATH
        CFG.BEST_MODEL_PATH = CFG.BEST_MODEL_DIGIT
    else:
        model = HNet(num_classes=36)
        logging.info("Initialized Vyanjan model")
        TRAIN_PATH = CFG.TRAIN_VYANJAN_PATH
        CFG.BEST_MODEL_PATH = CFG.BEST_MODEL_VYANJAN

    # Creating the dataset
    train_ds = ImageFolder(root=TRAIN_PATH, transform=transforms["train"])

    # Train/val splitting (80/20)
    lengths = [int(len(train_ds) * 0.8), len(train_ds) - int(len(train_ds) * 0.8)]
    train_ds, val_ds = random_split(dataset=train_ds, lengths=lengths)

    # Creating the dataloaders
    train_dl = DataLoader(dataset=train_ds, batch_size=CFG.BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(dataset=val_ds, batch_size=CFG.BATCH_SIZE)

    # The argparse defaults fall back to the config values, so the CLI
    # arguments can be applied unconditionally
    CFG.EPOCHS = args.epochs
    CFG.LR = args.lr

    # Results table
    table = PrettyTable(
        field_names=["Epoch", "Train Loss", "Train Acc", "Val Loss", "Val Acc"]
    )

    # The model
    model.to(CFG.DEVICE)

    # Setting up the optimizer, loss, and LR scheduler
    optimizer = SGD(model.parameters(), lr=CFG.LR)
    criterion = CrossEntropyLoss()
    scheduler = lr_scheduler.CyclicLR(
        optimizer=optimizer, base_lr=1e-5, max_lr=0.1, verbose=True
    )

    dataloaders = {"train": train_dl, "val": val_dl}
    ds_sizes = {"train": len(train_ds), "val": len(val_ds)}

    detail = f"""
    Training details:
    ------------------------
    Model: {model._get_name()}
    Model Type: {args.model_type}
    Epochs: {CFG.EPOCHS}
    Optimizer: {type(optimizer).__name__}
    Loss: {criterion._get_name()}
    Learning Rate: {CFG.LR}
    Learning Rate Scheduler: {type(scheduler).__name__}
    Batch Size: {CFG.BATCH_SIZE}
    Logging Interval: {CFG.INTERVAL} batches
    Train-dataset samples: {len(train_ds)}
    Validation-dataset samples: {len(val_ds)}
    -------------------------
    """
    print(detail)
    logging.info(detail)

    start_train = time.time()
    train(
        dataloaders=dataloaders,
        ds_sizes=ds_sizes,
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
    )
    end_train = time.time() - start_train
    print(f"Training completed in: {round(end_train / 60, 3)} mins")