# -*- coding: utf-8 -*-
# Base Trainer Class
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import logging
from abc import abstractmethod
from typing import Tuple, Union
import torch
import torch.nn as nn
import wandb
from base_ml.base_early_stopping import EarlyStopping
from pathlib import Path
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from utils.tools import flatten_dict
class BaseTrainer:
"""
Base class for all trainers with important ML components
Args:
model (nn.Module): Model that should be trained
loss_fn (_Loss): Loss function
optimizer (Optimizer): Optimizer
scheduler (_LRScheduler): Learning rate scheduler
        device (str): CUDA device to use, e.g., cuda:0.
logger (logging.Logger): Logger module
logdir (Union[Path, str]): Logging directory
experiment_config (dict): Configuration of this experiment
early_stopping (EarlyStopping, optional): Early Stopping Class. Defaults to None.
accum_iter (int, optional): Accumulation steps for gradient accumulation.
Provide a number greater than 1 for activating gradient accumulation. Defaults to 1.
mixed_precision (bool, optional): If mixed-precision should be used. Defaults to False.
log_images (bool, optional): If images should be logged to WandB. Defaults to False.
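
    Example:
        A minimal usage sketch; ``ExampleTrainer`` stands for a hypothetical
        concrete subclass that implements the abstract ``*_epoch`` and
        ``*_step`` methods, and the model/dataloaders are placeholders::

            model = ExampleModel()
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
            trainer = ExampleTrainer(
                model=model,
                loss_fn=torch.nn.CrossEntropyLoss(),
                optimizer=optimizer,
                scheduler=scheduler,
                device="cuda:0",
                logger=logging.getLogger(__name__),
                logdir="./logs",
                experiment_config={},
            )
            trainer.fit(epochs=100, train_dataloader=train_dl, val_dataloader=val_dl)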
"""
def __init__(
self,
model: nn.Module,
loss_fn: _Loss,
optimizer: Optimizer,
scheduler: _LRScheduler,
device: str,
logger: logging.Logger,
logdir: Union[Path, str],
experiment_config: dict,
early_stopping: EarlyStopping = None,
accum_iter: int = 1,
mixed_precision: bool = False,
log_images: bool = False,
#model_ema: bool = True,
) -> None:
self.model = model
self.loss_fn = loss_fn
self.optimizer = optimizer
self.scheduler = scheduler
self.device = device
self.logger = logger
self.logdir = Path(logdir)
self.early_stopping = early_stopping
self.accum_iter = accum_iter
self.start_epoch = 0
self.experiment_config = experiment_config
self.log_images = log_images
self.mixed_precision = mixed_precision
if self.mixed_precision:
self.scaler = torch.cuda.amp.GradScaler(enabled=True)
else:
self.scaler = None
@abstractmethod
def train_epoch(
self, epoch: int, train_loader: DataLoader, **kwargs
) -> Tuple[dict, dict]:
"""Training logic for a training epoch
Args:
epoch (int): Current epoch number
train_loader (DataLoader): Train dataloader
Raises:
NotImplementedError: Needs to be implemented
Returns:
Tuple[dict, dict]: wandb logging dictionaries
* Scalar metrics
* Image metrics
"""
raise NotImplementedError
@abstractmethod
def validation_epoch(
self, epoch: int, val_dataloader: DataLoader
) -> Tuple[dict, dict, float]:
"""Training logic for an validation epoch
Args:
epoch (int): Current epoch number
val_dataloader (DataLoader): Validation dataloader
Raises:
NotImplementedError: Needs to be implemented
Returns:
Tuple[dict, dict, float]: wandb logging dictionaries and early_stopping_metric
* Scalar metrics
* Image metrics
* Early Stopping metric as float
"""
raise NotImplementedError
@abstractmethod
def train_step(self, batch: object, batch_idx: int, num_batches: int):
"""Training logic for one training batch
Args:
batch (object): A training batch
batch_idx (int): Current batch index
num_batches (int): Maximum number of batches
Raises:
NotImplementedError: Needs to be implemented
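
        Example:
            A sketch of what a subclass implementation could look like with
            mixed precision and gradient accumulation; the ``imgs``/``masks``
            batch layout is an assumption and depends on the dataset::

                imgs, masks = batch
                imgs, masks = imgs.to(self.device), masks.to(self.device)
                with torch.autocast(device_type="cuda", enabled=self.mixed_precision):
                    predictions = self.model(imgs)
                    loss = self.loss_fn(predictions, masks) / self.accum_iter
                if self.scaler is not None:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                if (batch_idx + 1) % self.accum_iter == 0 or (batch_idx + 1) == num_batches:
                    if self.scaler is not None:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()
                    self.optimizer.zero_grad()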
"""
raise NotImplementedError
@abstractmethod
    def validation_step(self, batch: object, batch_idx: int):
        """Validation logic for one validation batch
        Args:
            batch (object): A validation batch
            batch_idx (int): Current batch index
        Raises:
            NotImplementedError: Needs to be implemented
        """
        raise NotImplementedError
def fit(
self,
epochs: int,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
metric_init: dict = None,
eval_every: int = 1,
**kwargs,
):
"""Fitting function to start training and validation of the trainer
Args:
            epochs (int): Number of epochs the network should be trained for
train_dataloader (DataLoader): Dataloader with training data
val_dataloader (DataLoader): Dataloader with validation data
            metric_init (dict, optional): Initialization dictionary with scalar metrics that should be initialized for startup.
                This is only relevant for logging with wandb if you want the plots to be scaled properly.
                The data in the metric dictionary is used as the values for epoch 0 (before training has started).
                If not provided, step 0 (epoch 0) is not logged. Should have the same scalar keys as the training and validation epochs report.
                For more information, have a look at the train_epoch and validation_epoch methods, where the wandb logging dicts are assembled.
                Defaults to None.
eval_every (int, optional): How often the network should be evaluated (after how many epochs). Defaults to 1.
**kwargs
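
        Example:
            A sketch of a typical call; the metric keys are assumptions and have
            to match whatever the subclass logs in its epoch methods::

                trainer.fit(
                    epochs=100,
                    train_dataloader=train_dl,
                    val_dataloader=val_dl,
                    metric_init={"Loss/Training": 0.0, "Loss/Validation": 0.0},
                    eval_every=1,
                )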
"""
self.logger.info(f"Starting training, total number of epochs: {epochs}")
if metric_init is not None and self.start_epoch == 0:
wandb.log(metric_init, step=0)
for epoch in range(self.start_epoch, epochs):
# training epoch
#train_sampler.set_epoch(epoch) # for distributed training
self.logger.info(f"Epoch: {epoch+1}/{epochs}")
train_scalar_metrics, train_image_metrics = self.train_epoch(
epoch, train_dataloader, **kwargs
)
wandb.log(train_scalar_metrics, step=epoch + 1)
if self.log_images:
wandb.log(train_image_metrics, step=epoch + 1)
            if epoch >= 95 and (epoch + 1) % eval_every == 0:
# validation epoch
(
val_scalar_metrics,
val_image_metrics,
early_stopping_metric,
) = self.validation_epoch(epoch, val_dataloader)
wandb.log(val_scalar_metrics, step=epoch + 1)
if self.log_images:
wandb.log(val_image_metrics, step=epoch + 1)
#self.save_checkpoint(epoch, f"checkpoint_{epoch}.pth")
# log learning rate
curr_lr = self.optimizer.param_groups[0]["lr"]
wandb.log(
{
"Learning-Rate/Learning-Rate": curr_lr,
},
step=epoch + 1,
)
            if epoch >= 95 and (epoch + 1) % eval_every == 0:
# early stopping
if self.early_stopping is not None:
best_model = self.early_stopping(early_stopping_metric, epoch)
if best_model:
self.logger.info("New best model - save checkpoint")
self.save_checkpoint(epoch, "model_best.pth")
elif self.early_stopping.early_stop:
self.logger.info("Performing early stopping!")
break
self.save_checkpoint(epoch, "latest_checkpoint.pth")
# scheduling
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.scheduler.step(float(val_scalar_metrics["Loss/Validation"]))
else:
self.scheduler.step()
new_lr = self.optimizer.param_groups[0]["lr"]
self.logger.debug(f"Old lr: {curr_lr:.6f} - New lr: {new_lr:.6f}")
def save_checkpoint(self, epoch: int, checkpoint_name: str):
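        """Save a training checkpoint to the checkpoints folder inside the logdir

        Stores model, optimizer, scheduler (and, if used, scaler) state dicts
        together with the flattened wandb config and run metadata.

        Args:
            epoch (int): Current epoch number
            checkpoint_name (str): Filename for the checkpoint, e.g., model_best.pth
        """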
if self.early_stopping is None:
best_metric = None
best_epoch = None
else:
best_metric = self.early_stopping.best_metric
best_epoch = self.early_stopping.best_epoch
arch = type(self.model).__name__
state = {
"arch": arch,
"epoch": epoch,
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"best_metric": best_metric,
"best_epoch": best_epoch,
"config": flatten_dict(wandb.config),
"wandb_id": wandb.run.id,
"logdir": str(self.logdir.resolve()),
"run_name": str(Path(self.logdir).name),
"scaler_state_dict": self.scaler.state_dict()
if self.scaler is not None
else None,
}
checkpoint_dir = self.logdir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True, parents=True)
filename = str(checkpoint_dir / checkpoint_name)
torch.save(state, filename)
    def resume_checkpoint(self, checkpoint: dict):
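        """Resume training from a checkpoint dictionary

        Restores model, optimizer and scheduler states (plus early stopping and
        gradient scaler states, if used) and sets the start epoch.

        Args:
            checkpoint (dict): Checkpoint dictionary as produced by save_checkpoint,
                e.g., loaded with torch.load("<logdir>/checkpoints/latest_checkpoint.pth")
        """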
self.logger.info("Loading checkpoint")
self.logger.info("Loading Model")
self.model.load_state_dict(checkpoint["model_state_dict"])
self.logger.info("Loading Optimizer state dict")
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
if self.early_stopping is not None:
self.early_stopping.best_metric = checkpoint["best_metric"]
self.early_stopping.best_epoch = checkpoint["best_epoch"]
if self.scaler is not None:
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
self.logger.info(f"Checkpoint epoch: {int(checkpoint['epoch'])}")
self.start_epoch = int(checkpoint["epoch"])
self.logger.info(f"Next epoch is: {self.start_epoch + 1}")