from typing import Iterable

import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning, Callback
from pytorch_lightning.utilities import rank_zero_info


class ConsoleLogger(Callback):
    """Keep an in-memory history of per-epoch metrics and print them to the console."""

    def __init__(self):
        super().__init__()
        self._reset()

    def get_history(self) -> list:
        return list(self._history)

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self._reset()

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Assumes a single LR scheduler reporting exactly one learning rate.
        [lr] = trainer.lr_scheduler_configs[0].scheduler.get_last_lr()  # type: ignore
        log = {"epoch": trainer.current_epoch, "lr": lr}
        log.update({name: tensor.item() for name, tensor in trainer.logged_metrics.items()})
        self._history.append(log)
        formatted = []
        for key, value in log.items():
            if isinstance(value, int):
                kv = f"{key}={value:3d}"
            elif isinstance(value, float):
                kv = f"{key}={value:.4f}"
            else:
                kv = f"{key}={value}"
            formatted.append(kv)
        rank_zero_info(" | ".join(formatted))

    def _reset(self):
        self._history = []


class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Freeze the model backbone before training and unfreeze it at a configurable epoch."""

    def __init__(self, unfreeze_at_epoch: int):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        rank_zero_info("Freezing backbone")
        self.freeze(_get_backbone(pl_module.model))
        # Report which top-level layers still have gradients enabled after freezing.
        enabled_layers = [
            name
            for name, child in pl_module.model.named_children()
            if all(param.requires_grad for param in child.parameters())
        ]
        rank_zero_info(f"Gradient enabled layers: [{', '.join(enabled_layers)}]")

    def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer, opt_idx: int) -> None:
        if epoch == self._unfreeze_at_epoch:
            rank_zero_info(f"Unfreezing backbone at epoch {epoch}")
            self.unfreeze_and_add_param_group(
                modules=_get_backbone(pl_module.model),
                optimizer=optimizer,
                train_bn=True,
            )


def _get_backbone(module: nn.Module) -> Iterable[nn.Module]:
    """Yield every top-level child of ``module`` except those whose name starts with "head"."""
    for name, child in module.named_children():
        if name.startswith("head"):
            continue
        yield child
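

# Minimal usage sketch (illustrative only): both callbacks are passed to the Trainer.
# It assumes the LightningModule stores its network as ``self.model`` with a backbone
# plus submodules whose names start with "head", and that it configures one optimizer
# with one LR scheduler; ``lit_model`` and ``train_loader`` are hypothetical placeholders.
#
#   trainer = pl.Trainer(
#       max_epochs=20,
#       callbacks=[
#           ConsoleLogger(),
#           FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=5),
#       ],
#   )
#   trainer.fit(lit_model, train_loader)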