from typing import Iterable

import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning, Callback
from pytorch_lightning.utilities import rank_zero_info


class ConsoleLogger(Callback):
    """Collects per-epoch metrics and prints them to the console on rank zero."""

    def __init__(self):
        super().__init__()
        self._reset()

    def get_history(self) -> list:
        return list(self._history)

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Start from a clean history on every new fit.
        self._reset()

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # The scheduler may track several param groups once the backbone is
        # unfrozen (see FeatureExtractorFreezeUnfreeze below), so read the
        # learning rate of the first group instead of destructuring the list.
        lr = trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0]
        log = {"epoch": trainer.current_epoch, "lr": lr}
        log.update({name: tensor.item() for name, tensor in trainer.logged_metrics.items()})
        self._history.append(log)
        formatted = []
        for key, value in log.items():
            if isinstance(value, int):
                kv = f"{key}={value:3d}"
            elif isinstance(value, float):
                kv = f"{key}={value:.4f}"
            else:
                kv = f"{key}={value}"
            formatted.append(kv)
        rank_zero_info(" | ".join(formatted))

    def _reset(self):
        self._history = []


class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Freezes the backbone before training and unfreezes it at a configurable epoch."""

    def __init__(self, unfreeze_at_epoch: int):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        rank_zero_info("Freezing backbone")
        self.freeze(_get_backbone(pl_module.model))
        # Report which top-level layers still require gradients (the head(s)).
        enabled_layers = [
            name
            for name, child in pl_module.model.named_children()
            if all(param.requires_grad for param in child.parameters())
        ]
        rank_zero_info(f"Gradient enabled layers: [{', '.join(enabled_layers)}]")

    def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer, opt_idx: int) -> None:
        if epoch == self._unfreeze_at_epoch:
            rank_zero_info(f"Unfreezing backbone at epoch {epoch}")
            # Make the backbone trainable again and register its parameters as
            # a new param group on the existing optimizer.
            self.unfreeze_and_add_param_group(
                modules=_get_backbone(pl_module.model),
                optimizer=optimizer,
                train_bn=True,
            )


def _get_backbone(module: nn.Module) -> Iterable[nn.Module]:
    # Every top-level child whose name does not start with "head" is
    # considered part of the backbone.
    for name, child in module.named_children():
        if name.startswith("head"):
            continue
        yield child
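

# ---------------------------------------------------------------------------
# Usage sketch (not from the original module): shows how the two callbacks fit
# together. `DemoModule`, its layer names, and the hyperparameters below are
# illustrative assumptions; the only real requirement is a `self.model` whose
# head layers are named "head*" so that `_get_backbone` can skip them.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class DemoModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # `_get_backbone` treats every child not named "head*" as backbone.
            self.model = nn.Sequential()
            self.model.add_module("backbone", nn.Linear(8, 16))
            self.model.add_module("head", nn.Linear(16, 2))

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = nn.functional.cross_entropy(self.model(x), y)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            # Optimize only the parameters that are trainable at setup time;
            # the finetuning callback later adds the backbone as a new group.
            optimizer = torch.optim.SGD(
                (p for p in self.parameters() if p.requires_grad), lr=0.1
            )
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2)
            return [optimizer], [scheduler]

    data = DataLoader(
        TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))), batch_size=8
    )
    trainer = pl.Trainer(
        max_epochs=4,
        callbacks=[ConsoleLogger(), FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=2)],
        enable_progress_bar=False,
    )
    trainer.fit(DemoModule(), data)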