File size: 2,512 Bytes
ea847ad 12babad ea847ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from typing import Iterable
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning, Callback
from pytorch_lightning.utilities import rank_zero_info
class ConsoleLogger(Callback):
    """Record per-epoch metrics and print them as a single console line.

    At the end of every training epoch a dict holding the epoch index, the
    current learning rate and every entry of ``trainer.logged_metrics`` is
    appended to an in-memory history and emitted through ``rank_zero_info``
    as ``key=value`` pairs joined by ``" | "``.
    """

    def __init__(self):
        super().__init__()
        self._reset()

    def get_history(self) -> list:
        """Return a shallow copy of the per-epoch log dicts recorded so far."""
        return list(self._history)

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Clear the history so each training run starts from an empty log."""
        self._reset()

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Collect this epoch's values, store them and print one summary line."""
        log = {"epoch": trainer.current_epoch, "lr": self._current_lr(trainer)}
        log.update({name: tensor.item() for name, tensor in trainer.logged_metrics.items()})
        self._history.append(log)
        rank_zero_info(" | ".join(self._format_entry(key, value) for key, value in log.items()))

    @staticmethod
    def _current_lr(trainer: pl.Trainer):
        """Return the learning rate of the first parameter group, or ``None``.

        The first scheduler's ``get_last_lr()`` is used when a scheduler is
        configured; otherwise the value is read from the first optimizer.
        """
        scheduler_configs = trainer.lr_scheduler_configs
        if scheduler_configs:
            lrs = list(scheduler_configs[0].scheduler.get_last_lr())
        else:
            # No scheduler configured: read the rate straight from the optimizer.
            lrs = [
                group["lr"]
                for optimizer in trainer.optimizers[:1]
                for group in optimizer.param_groups
            ]
        # Index instead of unpacking ``[lr] = ...``: the list has one entry per
        # parameter group, so strict unpacking raises as soon as a second
        # group is added to the optimizer (e.g. when a backbone is unfrozen).
        return lrs[0] if lrs else None

    @staticmethod
    def _format_entry(key: str, value) -> str:
        """Format one ``key=value`` pair: ints padded to 3, floats to 4 decimals."""
        if isinstance(value, int):
            return f"{key}={value:3d}"
        if isinstance(value, float):
            return f"{key}={value:.4f}"
        return f"{key}={value}"

    def _reset(self):
        """Start a fresh, empty history list."""
        self._history = []
class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Freeze the model backbone before training and unfreeze it at a set epoch.

    The backbone is whatever ``_get_backbone(pl_module.model)`` yields, i.e.
    the direct children of ``pl_module.model`` whose name does not start with
    ``"head"``.

    Args:
        unfreeze_at_epoch: Epoch index at which the backbone is unfrozen and
            added to the optimizer as a new parameter group.
    """

    def __init__(self, unfreeze_at_epoch: int):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        """Freeze the backbone and log which top-level layers remain trainable."""
        rank_zero_info("Freezing backbone")
        self.freeze(_get_backbone(pl_module.model))
        enabled_layers = []
        for name, child in pl_module.model.named_children():
            grad_flags = [param.requires_grad for param in child.parameters()]
            # ``all([])`` is True, so a child without parameters would be
            # reported as trainable; require at least one parameter.
            if grad_flags and all(grad_flags):
                enabled_layers.append(name)
        rank_zero_info(f"Gradient enabled layers: [{', '.join(enabled_layers)}]")

    def finetune_function(
        self,
        pl_module: "pl.LightningModule",
        epoch: int,
        optimizer,
        opt_idx: int = 0,
    ) -> None:
        """Unfreeze the backbone once ``epoch`` reaches the configured epoch.

        Args:
            pl_module: Module whose ``model`` attribute holds the network.
            epoch: Current epoch index.
            optimizer: Optimizer that receives the backbone parameters as a
                new parameter group.
            opt_idx: Optimizer index; unused by this callback. It defaults to
                ``0`` so the hook can be called both with the index
                (Lightning 1.x) and without it (Lightning 2.x).
        """
        if epoch != self._unfreeze_at_epoch:
            return
        rank_zero_info(f"Unfreezing backbone at epoch {epoch}")
        self.unfreeze_and_add_param_group(
            modules=_get_backbone(pl_module.model),
            optimizer=optimizer,
            train_bn=True,
        )
def _get_backbone(module: pl.LightningModule) -> Iterable[nn.Module]:
    """Lazily yield the direct children of ``module`` that are not heads.

    A child is skipped when its registered name starts with ``"head"``; every
    other direct child is yielded in ``named_children()`` order.

    NOTE(review): the parameter is annotated as ``LightningModule``, but only
    ``named_children()`` is called on it and the callers in this file pass
    ``pl_module.model``.
    """
    yield from (
        submodule
        for child_name, submodule in module.named_children()
        if not child_name.startswith("head")
    )
|