from typing import Iterable, List, Optional

import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning, Callback
from pytorch_lightning.utilities import rank_zero_info


class ConsoleLogger(Callback):
    """Print a one-line summary of every training epoch and keep a history of them.

    At the end of each training epoch the callback builds a dict containing the
    epoch index, the current learning rate (when one can be determined) and
    every entry of ``trainer.logged_metrics`` (converted with ``.item()``),
    appends it to an in-memory history and prints it via ``rank_zero_info``.
    """

    def __init__(self):
        super().__init__()
        self._reset()

    def get_history(self) -> list:
        """Return a shallow copy of the per-epoch log dicts recorded so far."""
        return list(self._history)

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Each fit starts with an empty history.
        self._reset()

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        log = {"epoch": trainer.current_epoch}
        lr = self._current_lr(trainer)
        if lr is not None:
            log["lr"] = lr
        log.update({name: tensor.item() for name, tensor in trainer.logged_metrics.items()})
        self._history.append(log)

        formatted = [self._format_entry(key, value) for key, value in log.items()]
        rank_zero_info(" | ".join(formatted))

    @staticmethod
    def _current_lr(trainer: pl.Trainer) -> Optional[float]:
        """Return the learning rate of the first param group, or ``None`` if unknown.

        Prefers the first LR scheduler's ``get_last_lr()``. It deliberately takes
        element 0 instead of unpacking exactly one value, because the optimizer can
        gain extra param groups during training (e.g. when a finetuning callback
        unfreezes a backbone). Falls back to the first optimizer's first param group
        when no scheduler is configured.
        """
        if trainer.lr_scheduler_configs:
            last_lrs = trainer.lr_scheduler_configs[0].scheduler.get_last_lr()  # type: ignore
            if last_lrs:
                return last_lrs[0]
        optimizers = trainer.optimizers
        if optimizers and optimizers[0].param_groups:
            return optimizers[0].param_groups[0]["lr"]
        return None

    @staticmethod
    def _format_entry(key: str, value) -> str:
        """Render one ``key=value`` pair: ints padded to width 3, floats with 4 decimals."""
        if isinstance(value, int):
            return f"{key}={value:3d}"
        if isinstance(value, float):
            return f"{key}={value:.4f}"
        return f"{key}={value}"

    def _reset(self):
        self._history = []


class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Freeze the backbone of ``pl_module.model`` before training and unfreeze it later.

    The "backbone" is every direct child of ``pl_module.model`` whose name does
    not start with ``"head"`` (see ``_get_backbone``).

    Args:
        unfreeze_at_epoch: epoch at which the backbone is made trainable again
            and its parameters are added to the optimizer as a new param group.
    """

    def __init__(self, unfreeze_at_epoch: int):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        rank_zero_info("Freezing backbone")
        self.freeze(_get_backbone(pl_module.model))

    def finetune_function(
        self,
        pl_module: "pl.LightningModule",
        epoch: int,
        optimizer,
        opt_idx: Optional[int] = None,
    ) -> None:
        """Unfreeze the backbone once ``epoch`` reaches ``unfreeze_at_epoch``.

        ``opt_idx`` is optional because older Lightning releases pass an optimizer
        index here and newer ones do not; it is not used.
        """
        if epoch == self._unfreeze_at_epoch:
            rank_zero_info(f"Unfreezing backbone at epoch {epoch}")
            self.unfreeze_and_add_param_group(
                modules=_get_backbone(pl_module.model),
                optimizer=optimizer,
                train_bn=True,
            )


def _get_backbone(module: nn.Module) -> List[nn.Module]:
    """Return the direct children of *module* whose names do not start with ``"head"``.

    A list (not a generator) is returned on purpose. Lightning's
    ``BaseFinetuning.unfreeze_and_add_param_group`` walks its ``modules`` argument
    more than once: once to make the modules trainable, and again to collect their
    parameters. A generator would already be exhausted on the second pass, so no
    parameters would be added to the optimizer.
    """
    return [child for name, child in module.named_children() if not name.startswith("head")]