''' This module contains utility functions for model training '''
# Handling files
import os

# Datetime
from datetime import datetime

# Plotting confusion matrix
from plot_utils import plot_confusion_matrix, plot_training_progress

# Torch
import torch
from torchmetrics.classification import F1Score

# Progress bar
from tqdm import tqdm


def evaluate(model, loss_func, val_loader, device, cm=False):
    '''
    Evaluate a model's performance through loss, accuracy, and weighted F1 score

    Args:
        model : the model to evaluate
        loss_func (torch.nn.Module) : the loss function used
        val_loader (torch.utils.data.DataLoader): the loader used to load data
        device (str) : the device used
        cm (bool) : whether to plot the confusion matrix and top-k
            misclassified classes table, default: False

    Return:
        val_loss (float): validation loss summed over the dataset
        accuracy (float): accuracy
        f1_score (float): weighted F1 score
    '''
    # Set up
    model.to(device)
    val_loss = 0.0
    f1 = F1Score(
        task="multiclass",
        num_classes=len(val_loader.dataset.classes),
        average="weighted"
    ).to(device)
    all_preds = []
    all_labels = []

    # Evaluate
    model.eval()
    with torch.no_grad():
        for imgs, labels, _ in tqdm(val_loader, desc="Evaluating"):
            imgs = imgs.to(device)
            labels = labels.to(device)
            # Hugging Face-style model output; take the classification logits
            outputs = model(imgs).logits
            loss = loss_func(outputs, labels)
            preds = torch.argmax(outputs, dim=1)
            # Weight the batch-mean loss by batch size so the caller can
            # average per sample (the last batch may be smaller)
            val_loss += loss.item() * imgs.size(0)
            all_preds.append(preds)
            all_labels.append(labels)

    # Concatenate predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    accuracy = (all_preds == all_labels).sum().item() / len(all_labels)
    f1_score = f1(all_preds, all_labels).item()

    # Plot confusion matrix if required
    if cm:
        current_time = datetime.now().strftime("%Y%m%dT%H%M%S")
        plot_confusion_matrix(
            y_true=all_labels.cpu(),
            y_pred=all_preds.cpu(),
            display_labels=val_loader.dataset.classes,
            save_path=f"/dinosaur_project/test_results/{current_time}_evaluation_result.png"
        )

    return val_loss, accuracy, f1_score


def train_epoch(
    model,
    loss_func,
    optimizer,
    train_loader,
    val_loader,
    device,
    scheduler=None,
    mix_augment=None
):
    '''
    Train a model for one epoch

    Args:
        model : the model to train
        loss_func (torch.nn.Module) : the loss function used
        optimizer (torch.optim.Optimizer) : the optimizer used
        train_loader (torch.utils.data.DataLoader) : the loader used to load training data
        val_loader (torch.utils.data.DataLoader) : the loader used to load validation data
        device (str) : the device used
        scheduler (torch.optim.lr_scheduler.LRScheduler): learning rate scheduler, default: None
        mix_augment : mixup/cutmix augmentation, default: None

    Return:
        avg_train_loss (float): average training loss per sample
        avg_val_loss (float) : average validation loss per sample
        accuracy (float) : accuracy
        f1_score (float) : weighted F1 score
        lr (float) : learning rate
    '''
    # Set up
    model.to(device)
    train_loss = 0.0

    # Train
    model.train()
    for imgs, labels, _ in tqdm(train_loader, desc="Training"):
        imgs = imgs.to(device)
        labels = labels.to(device)
        # Use mixup/cutmix augmentation if required
        if mix_augment:
            imgs, labels = mix_augment(imgs, labels)
        optimizer.zero_grad()
        outputs = model(imgs).logits
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch average
        # below is a true per-sample average
        train_loss += loss.item() * imgs.size(0)

    # Evaluate
    val_loss, accuracy, f1_score = evaluate(
        model, loss_func, val_loader, device, cm=False
    )

    # Current learning rate
    lr = optimizer.param_groups[0]["lr"]

    # Update scheduler if required (stepped once per epoch)
    if scheduler:
        scheduler.step()

    # Calculate average train and validation loss per sample
    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_val_loss = val_loss / len(val_loader.dataset)

    print(
        f"Average Train Loss: {avg_train_loss:.4f}",
        f"Average Validation Loss: {avg_val_loss:.4f}",
        f"Accuracy: {accuracy:.4f}",
        f"Weighted F1: {f1_score:.4f}",
        sep=" | "
    )

    return avg_train_loss, avg_val_loss, accuracy, f1_score, lr
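
# --- Illustrative sketch (not part of the original utilities) ----------------
# One way to build the `mix_augment` callable that train_epoch accepts, using
# torchvision's v2 MixUp/CutMix transforms. Assumptions: torchvision >= 0.16
# is available, and loss_func accepts soft (probability) targets, as
# torch.nn.CrossEntropyLoss does. `build_mix_augment` is a hypothetical helper,
# not something the rest of this project defines.
def build_mix_augment(num_classes):
    from torchvision.transforms import v2  # local import: optional dependency

    # Randomly apply either MixUp or CutMix to each batch; both return
    # (mixed_images, soft_labels) when called as mix(imgs, labels)
    return v2.RandomChoice([
        v2.MixUp(num_classes=num_classes),
        v2.CutMix(num_classes=num_classes),
    ])
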
def train(
    model,
    n_epochs,
    loss_func,
    optimizer,
    train_loader,
    val_loader,
    device,
    early_stopping_patience=3,
    scheduler=None,
    mix_augment=None,
    model_dir="/dinosaur_project/model",
    train_plot_dir="/dinosaur_project/train_process"
):
    '''
    Train a model for a number of epochs

    Args:
        model, loss_func, optimizer, train_loader,
        val_loader, device, scheduler, mix_augment : same as in train_epoch
        n_epochs (int) : number of epochs to train
        early_stopping_patience (int) : number of epochs without F1 improvement
            before early stopping is triggered, default: 3
        model_dir (str) : directory to save model checkpoints,
            default: /dinosaur_project/model
        train_plot_dir (str) : directory to save the training process plot,
            default: /dinosaur_project/train_process

    Return:
        None
    '''
    # Set up
    current_time = datetime.now().strftime("%Y%m%dT%H%M%S")
    best_model_path = os.path.join(model_dir, f"{current_time}_best_model")
    train_process_plot_path = os.path.join(
        train_plot_dir, f"{current_time}_train_process.png"
    )
    avg_train_losses = []
    avg_val_losses = []
    accuracy_scores = []
    f1_scores = []
    learning_rates = []
    best_f1 = 0.0
    best_f1_epoch = 1
    early_stopping_cnt = 0

    # Train epochs
    for i in range(n_epochs):
        print(f"Epoch {i+1}:")
        avg_train_loss, avg_val_loss, accuracy, f1_score, lr = train_epoch(
            model, loss_func, optimizer, train_loader, val_loader,
            device, scheduler, mix_augment
        )
        avg_train_losses.append(avg_train_loss)
        avg_val_losses.append(avg_val_loss)
        accuracy_scores.append(accuracy)
        f1_scores.append(f1_score)
        learning_rates.append(lr)

        # Check early stopping and save best model
        if f1_score <= best_f1:
            early_stopping_cnt += 1
        else:
            best_f1 = f1_score
            best_f1_epoch = i + 1
            early_stopping_cnt = 0
            model.save_pretrained(best_model_path)
        if early_stopping_cnt >= early_stopping_patience:
            print(
                f"Early stopping triggered. Best weighted F1: {best_f1:.4f},",
                f"achieved on epoch {best_f1_epoch}"
            )
            break

    # Plot training process
    plot_training_progress(
        avg_training_losses=avg_train_losses,
        avg_val_losses=avg_val_losses,
        accuracy_scores=accuracy_scores,
        f1_scores=f1_scores,
        lr_changes=learning_rates,
        show=False,
        save_path=train_process_plot_path
    )

    print(
        f"Best model is saved to {best_model_path}",
        f"Training process plot is saved to {train_process_plot_path}",
        sep="\n"
    )
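
# --- Illustrative usage sketch -----------------------------------------------
# A minimal sketch of how these utilities might be wired together. Everything
# below is an assumption, not part of this module: the loaders must yield
# (image, label, extra) triples, their datasets must expose a `.classes` list,
# and the model must return an object with `.logits` and implement
# `save_pretrained` (e.g. a Hugging Face transformers classifier).
# `make_loaders` is a hypothetical data-pipeline helper.
if __name__ == "__main__":
    from transformers import AutoModelForImageClassification

    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loader, val_loader = make_loaders()  # hypothetical helper
    num_classes = len(train_loader.dataset.classes)

    model = AutoModelForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k",  # example checkpoint
        num_labels=num_classes,
    )
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    train(
        model,
        n_epochs=10,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        scheduler=scheduler,
        mix_augment=build_mix_augment(num_classes),
    )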