from abc import ABC, abstractmethod
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

import modeling.loss as loss_module
import modeling.metrics as metrics_module
from modeling.loss import HardDistillationLoss
from modeling.models import freeze, layerwise_lr_decay
from modeling.utils import init_obj


class BaseLearner(ABC):
    """
    Abstract base class for a learner.

    :param train_dl: DataLoader for training data
    :type train_dl: DataLoader
    :param valid_dl: DataLoader for validation data
    :type valid_dl: DataLoader
    :param model: Model to be trained
    :type model: nn.Module
    :param config: Configuration object
    :type config: Any
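
    Subclasses implement ``fit``, ``_train_epoch`` and ``_test_epoch``. A minimal
    sketch (the body shown is illustrative only, not part of this module)::

        class ConstantLearner(BaseLearner):
            def fit(self):
                for _ in range(self.config.epochs):
                    self._train_epoch()
                    self._test_epoch()

            def _train_epoch(self):
                return 0.0

            def _test_epoch(self):
                return 0.0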
    """

    def __init__(self, train_dl: DataLoader, valid_dl: DataLoader, model: nn.Module, config):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        self.config = config

    @abstractmethod
    def fit(self):
        """Abstract method for fitting the model."""
        pass

    @abstractmethod
    def _train_epoch(self):
        """Abstract method for training the model for one epoch."""
        pass

    @abstractmethod
    def _test_epoch(self):
        """Abstract method for testing the model for one epoch."""
        pass


class Learner(BaseLearner):
    def __init__(self, train_dl: DataLoader, valid_dl: DataLoader, model: nn.Module, config):
        """
        Concrete learner built on ``BaseLearner``: wraps the model in ``DataParallel``
        and builds the loss, optimizer, scheduler, metric tracker, and AMP gradient
        scaler from the configuration.

        :param train_dl: DataLoader for training data
        :type train_dl: DataLoader
        :param valid_dl: DataLoader for validation data
        :type valid_dl: DataLoader
        :param model: Model to be trained
        :type model: nn.Module
        :param config: Configuration object
        :type config: Any
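
        Example (hypothetical usage: ``MyModel`` is a stand-in for any
        ``nn.Module``, and ``config`` must provide the fields read below, e.g.
        ``num_gpus``, ``loss``, ``optimizer``, ``scheduler``, ``epochs``,
        ``num_accum``, ``verbose``, ``metrics`` and ``save_last_checkpoint``;
        their exact format depends on ``init_obj``)::

            model = MyModel()
            learner = Learner(train_dl, valid_dl, model, config)
            learner.fit(model_name="baseline")  # writes baseline.pth if config.save_last_checkpoint is set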
        """
        super().__init__(train_dl, valid_dl, model, config)

        self.model = torch.nn.DataParallel(module=self.model, device_ids=list(range(config.num_gpus)))
        self.loss_fn = init_obj(self.config.loss, loss_module)
        params = layerwise_lr_decay(self.config, self.model)
        self.optimizer = init_obj(self.config.optimizer, optim, params)
        self.scheduler = init_obj(
            self.config.scheduler,
            optim.lr_scheduler,
            self.optimizer,
            max_lr=[param["lr"] for param in params],
            epochs=self.config.epochs,
            steps_per_epoch=int(np.ceil(len(train_dl) / self.config.num_accum)),
        )

        self.verbose = self.config.verbose
        self.metrics = MetricTracker(self.config.metrics, self.verbose)
        self.scaler = torch.cuda.amp.GradScaler()

        self.train_step = 0
        self.test_step = 0

    def fit(self, model_name: str = "model"):
        """
        Train the model for ``config.epochs`` epochs, logging train and validation
        losses to wandb after every epoch.

        :param model_name: Name of the model to be saved, defaults to "model"
        :type model_name: str, optional
        """
        loop = tqdm(range(self.config.epochs), leave=False)

        for epoch in loop:
            train_loss = self._train_epoch()
            val_loss = self._test_epoch()

            wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch + 1})

            if self.verbose:
                print(f"| EPOCH: {epoch+1} | train_loss: {train_loss:.3f} | val_loss: {val_loss:.3f} |\n")
                self.metrics.display()

        if self.config.save_last_checkpoint:
            torch.save(self.model.module.state_dict(), f"{model_name}.pth")

    def _train_epoch(self, distill: bool = False):
        """
        Run one epoch of training with mixed precision and gradient accumulation.

        :param distill: Flag to indicate if knowledge distillation is used, defaults to False
        :type distill: bool, optional
        :return: Average training loss for the epoch
        :rtype: float
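
        Gradients are accumulated over ``config.num_accum`` batches: the loss is
        divided by ``num_accum`` and the optimizer and scheduler step only every
        ``num_accum``-th batch (or on the last one), so the effective batch size
        is the DataLoader batch size times ``num_accum`` (e.g. a batch size of 32
        with ``num_accum = 4`` gives optimizer steps over 128 samples).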
        """
        if distill:
            print("Distilling knowledge...", flush=True)

        loop = tqdm(self.train_dl, leave=False)
        self.model.train()

        num_batches = len(self.train_dl)
        train_loss = 0

        for idx, (xb, yb) in enumerate(loop):
            xb = xb.to(self.device)
            yb = yb.to(self.device)

            # Forward pass under float16 autocast (disabled when distilling).
            with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=not distill):
                predictions = self.model(xb)

                if distill:
                    # KDloss_fn is defined by KDLearner, which calls this method with distill=True.
                    loss = self.KDloss_fn(xb, predictions, yb)
                else:
                    loss = self.loss_fn(predictions, yb)

                # Scale the loss so gradients average correctly over accumulated batches.
                loss /= self.config.num_accum

            self.scaler.scale(loss).backward()
            wandb.log({f"lr_param_group_{i}": lr for i, lr in enumerate(self.scheduler.get_last_lr())})

            # Step the optimizer and scheduler only every num_accum-th batch (or on the last batch).
            if ((idx + 1) % self.config.num_accum == 0) or (idx + 1 == num_batches):
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad()

            loop.set_postfix(loss=loss.item())
            self.train_step += 1
            wandb.log({"train_loss_per_batch": loss.item(), "train_step": self.train_step})
            train_loss += loss.item()

            # During distillation, run an extra validation pass every 2500 batches.
            if distill:
                if ((idx + 1) % 2500 == 0) and not (idx + 1 == num_batches):
                    val_loss = self._test_epoch()
                    wandb.log({"val_loss": val_loss})
                    self.model.train()

        train_loss /= num_batches

        return train_loss

    def _test_epoch(self):
        """
        Run one epoch of validation/testing and update the metric tracker.

        :return: Average validation/test loss for the epoch
        :rtype: float
        """
        loop = tqdm(self.valid_dl, leave=False)
        self.model.eval()

        num_batches = len(self.valid_dl)
        preds = []
        targets = []
        test_loss = 0

        with torch.no_grad():
            for xb, yb in loop:
                xb, yb = xb.to(self.device), yb.to(self.device)
                pred = self.model(xb)
                loss = self.loss_fn(pred, yb).item()
                self.test_step += 1
                wandb.log({"valid_loss_per_batch": loss, "test_step": self.test_step})
                test_loss += loss

                # Metrics are computed on sigmoid probabilities rather than raw logits.
                pred = torch.sigmoid(pred)
                preds.extend(pred.cpu().numpy())
                targets.extend(yb.cpu().numpy())

        preds, targets = np.array(preds), np.array(targets)
        self.metrics.update(preds, targets)
        test_loss /= num_batches

        return test_loss


class KDLearner(Learner):
    """
    Learner for training a student model with knowledge distillation against a frozen teacher.

    :param train_dl: Train data loader
    :type train_dl: DataLoader
    :param valid_dl: Validation data loader
    :type valid_dl: DataLoader
    :param student_model: Student model to be trained
    :type student_model: nn.Module
    :param teacher: Teacher model for knowledge distillation
    :type teacher: nn.Module
    :param thresholds: Thresholds for HardDistillationLoss
    :type thresholds: List[float]
    :param config: Configuration object for training
    :type config: Config
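
    Example (hypothetical usage: ``student`` and ``teacher`` stand for any
    compatible ``nn.Module`` instances, and ``thresholds`` must match what
    ``HardDistillationLoss`` expects)::

        kd_learner = KDLearner(train_dl, valid_dl, student, teacher, thresholds, config)
        kd_learner.fit(model_name="student_distilled")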
    """

    def __init__(self, train_dl, valid_dl, student_model, teacher, thresholds, config):
        super().__init__(train_dl, valid_dl, student_model, config)

        self.teacher = nn.DataParallel(freeze(teacher).to(self.device))
        self.KDloss_fn = HardDistillationLoss(self.teacher, self.loss_fn, thresholds, self.device)
        # Distillation runs in full precision (autocast is disabled), so gradient scaling is disabled too.
        self.scaler = torch.cuda.amp.GradScaler(enabled=False)

    def _train_epoch(self):
        """
        Run one epoch of training with knowledge distillation.

        :return: Average training loss for the epoch
        :rtype: float
        """
        return super()._train_epoch(distill=True)


class MetricTracker:
    """
    Tracks evaluation metrics during model validation.

    For every validation pass it computes the provided metric functions on the
    collected predictions and targets and logs the results to Weights & Biases
    via wandb.log(). The display() method prints the most recent results.

    :param metrics: Names of metric functions to track, looked up in modeling.metrics
    :type metrics: List[str]
    :param verbose: Flag to indicate whether to print the results or not, defaults to True
    :type verbose: bool, optional
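
    Example (illustrative: the metric names must exist as functions in
    ``modeling.metrics``, so ``"f1"`` and ``"accuracy"`` are assumptions, and
    ``wandb.log`` requires an active wandb run)::

        tracker = MetricTracker(["f1", "accuracy"], verbose=True)
        tracker.update(preds, targets)  # preds and targets are NumPy arrays
        tracker.display()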
    """

    def __init__(self, metrics, verbose: bool = True):
        self.metrics_fn = [getattr(metrics_module, metric) for metric in metrics]
        self.verbose = verbose
        self.result = None

    def update(self, preds, targets):
        """
        Update the metric tracker with the latest predictions and targets.

        :param preds: Model predictions
        :type preds: np.ndarray
        :param targets: Ground truth targets
        :type targets: np.ndarray
        """
        self.result = {metric.__name__: metric(preds, targets) for metric in self.metrics_fn}
        wandb.log(self.result)

    def display(self):
        """Display the tracked metric results."""
        for k, v in self.result.items():
            print(f"{k}: {v:.2f}")


def get_preds(data: DataLoader, model: nn.Module, device: str = "cpu") -> Tuple[np.ndarray, np.ndarray]:
    """
    Get sigmoid predictions and targets from a data loader and a PyTorch model.

    :param data: A PyTorch DataLoader containing the data to predict on.
    :type data: torch.utils.data.DataLoader
    :param model: A PyTorch model to use for predictions.
    :type model: torch.nn.Module
    :param device: The device to use for predictions (default is "cpu").
    :type device: str
    :raises TypeError: If any of the input arguments is of an incorrect type.
    :return: A tuple containing two NumPy arrays: the predictions and the targets.
    :rtype: Tuple[numpy.ndarray, numpy.ndarray]
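
    Example (illustrative: assumes the model outputs logits for a binary or
    multi-label task, since a sigmoid is applied before returning)::

        preds, targets = get_preds(valid_dl, model, device="cuda")
        accuracy = ((preds > 0.5) == targets).mean()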
    """
    if not isinstance(data, DataLoader):
        raise TypeError("The 'data' argument must be a PyTorch DataLoader.")
    if not isinstance(model, nn.Module):
        raise TypeError("The 'model' argument must be a PyTorch model.")
    if not isinstance(device, str):
        raise TypeError("The 'device' argument must be a string.")

    loop = tqdm(data, leave=False)
    model = model.to(device)
    model.eval()

    preds = []
    targets = []

    with torch.no_grad():
        for xb, yb in loop:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            pred = torch.sigmoid(pred)
            preds.extend(pred.cpu().numpy())
            targets.extend(yb.cpu().numpy())

    preds, targets = np.array(preds), np.array(targets)

    return preds, targets