# realfake/realfake/callbacks.py
# devforfu
# Fine-tuning support
# 12babad
from typing import Iterable
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning, Callback
from pytorch_lightning.utilities import rank_zero_info
class ConsoleLogger(Callback):
    """Print a one-line summary at the end of every training epoch and record it.

    Each summary is a dict holding the epoch index, the learning rate of every
    optimizer parameter group, and every entry of ``trainer.logged_metrics``.
    Summaries are appended to an in-memory history (see ``get_history``) and
    printed on rank zero as ``key=value`` pairs joined by ``" | "``.
    """

    def __init__(self):
        super().__init__()
        self._reset()

    def get_history(self) -> list:
        """Return a shallow copy of the per-epoch summaries recorded so far."""
        return list(self._history)

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Every fit starts with an empty history.
        self._reset()

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        log = {"epoch": trainer.current_epoch}
        lr_keys = set()
        for idx, lr in enumerate(self._current_lrs(trainer)):
            # The first param group keeps the historical "lr" key. Extra groups (e.g. the
            # one added by unfreeze_and_add_param_group when fine-tuning) get numbered
            # keys instead of crashing a single-element unpack.
            key = "lr" if idx == 0 else f"lr_{idx}"
            log[key] = lr
            lr_keys.add(key)
        log.update(
            {
                # Metrics are usually tensors; tolerate plain Python numbers as well.
                name: value.item() if hasattr(value, "item") else value
                for name, value in trainer.logged_metrics.items()
            }
        )
        self._history.append(log)
        rank_zero_info(
            " | ".join(self._format(key, value, key in lr_keys) for key, value in log.items())
        )

    @staticmethod
    def _current_lrs(trainer: pl.Trainer) -> list:
        """Return the learning rate of each param group, or ``[]`` if nothing is available.

        The first LR scheduler's ``get_last_lr()`` is preferred. Without a scheduler,
        the rates are read from the first optimizer's param groups.
        """
        configs = trainer.lr_scheduler_configs
        if configs:
            return list(configs[0].scheduler.get_last_lr())  # type: ignore
        if trainer.optimizers:
            return [group["lr"] for group in trainer.optimizers[0].param_groups]
        return []

    @staticmethod
    def _format(key: str, value, is_lr: bool = False) -> str:
        """Render one ``key=value`` pair of the console summary."""
        if isinstance(value, int):
            return f"{key}={value:3d}"
        if isinstance(value, float):
            # Learning rates are often far below 1e-4 and would print as 0.0000 with a
            # fixed four-decimal format, so show them in scientific notation.
            return f"{key}={value:.2e}" if is_lr else f"{key}={value:.4f}"
        return f"{key}={value}"

    def _reset(self):
        """Clear the recorded per-epoch history."""
        self._history = []
class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Train only the model head first, then unfreeze the backbone at a given epoch.

    The backbone is every direct child of ``pl_module.model`` whose name does not
    start with ``"head"``. Those children are frozen before training starts. At
    epoch ``unfreeze_at_epoch`` they are made trainable again and handed to the
    optimizer through ``unfreeze_and_add_param_group``.

    Args:
        unfreeze_at_epoch: Epoch index at which the backbone is unfrozen.
    """

    def __init__(self, unfreeze_at_epoch: int):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        """Freeze the backbone of ``pl_module.model`` and log the fully trainable children."""
        rank_zero_info("Freezing backbone")
        # NOTE(review): freeze() is called with its default train_bn; confirm whether
        # normalization layers inside the backbone are meant to stay trainable.
        self.freeze(_get_backbone(pl_module.model))
        enabled_layers = []
        for name, child in pl_module.model.named_children():
            params = list(child.parameters())
            # Parameter-less children (activations, pooling, ...) would pass a bare
            # all(...) vacuously; they are not "gradient enabled", so skip them.
            if params and all(param.requires_grad for param in params):
                enabled_layers.append(name)
        rank_zero_info(f"Gradient enabled layers: [{', '.join(enabled_layers)}]")

    def finetune_function(
        self,
        pl_module: "pl.LightningModule",
        epoch: int,
        optimizer,
        opt_idx: int = 0,
    ) -> None:
        """Unfreeze the backbone when ``epoch`` equals the configured unfreeze epoch.

        Args:
            pl_module: Module whose ``model`` attribute holds the backbone.
            epoch: Current training epoch.
            optimizer: Optimizer that receives the backbone parameters.
            opt_idx: Optimizer index. Lightning 1.x passes it and 2.x does not, hence
                the default. Unused here.
        """
        if epoch != self._unfreeze_at_epoch:
            return
        rank_zero_info(f"Unfreezing backbone at epoch {epoch}")
        # unfreeze_and_add_param_group iterates `modules` more than once (it makes them
        # trainable, then collects their params). A one-shot generator would be exhausted
        # by the first pass and no param group would be added, so pass a list.
        self.unfreeze_and_add_param_group(
            modules=list(_get_backbone(pl_module.model)),
            optimizer=optimizer,
            train_bn=True,
        )
def _get_backbone(module: nn.Module, head_prefix: str = "head") -> "list[nn.Module]":
    """Return the direct children of ``module`` that are not part of its head.

    A child counts as part of the head when its attribute name starts with
    ``head_prefix``. The result is a list rather than a generator so callers can
    iterate it more than once. Lightning's ``unfreeze_and_add_param_group``
    walks its ``modules`` argument twice, and a generator would silently yield
    nothing on the second pass.

    Args:
        module: Model whose direct children are inspected (no recursion).
        head_prefix: Name prefix identifying head children to exclude.

    Returns:
        The non-head children, in registration order.
    """
    return [
        child
        for name, child in module.named_children()
        if not name.startswith(head_prefix)
    ]