# Copyright (c) Meta Platforms, Inc. and affiliates.
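"""Generic PyTorch Lightning wrapper around the models defined in `.models`.

Wires a model into the Lightning training and validation loops, logs losses and
metrics, builds the optimizer and optional LR scheduler from the config, and
extends checkpoint loading so that the best checkpoint tracked by a
ModelCheckpoint callback can be selected automatically.
"""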
from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

from . import logger
from .models import get_model


class AverageKeyMeter(MeanMetric):
    """Running mean of one entry of a loss dict, ignoring non-finite values."""

    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, dict):
        value = dict[self.key]
        # Drop NaN/Inf entries so they do not corrupt the running mean.
        value = value[torch.isfinite(value)]
        return super().update(value)


class GenericModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        # Fall back to the default model if the config does not name one.
        name = cfg.model.get("name")
        name = "map_perception_net" if name is None else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        self.losses_val = None  # we do not know the loss keys in advance

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"train/loss/{k}": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
            on_epoch=True,
            sync_dist=True,
        )
        return losses["total"].mean()

    def validation_step(self, batch, batch_idx):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            # Lazily create one meter per loss key once the keys are known.
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="val/",
                postfix="/loss",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, on_epoch=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, on_epoch=True)
        return pred

    def test_step(self, batch, batch_idx):
        pred = self(batch)
        return pred

    def on_validation_epoch_start(self):
        # Reset so the loss meters are rebuilt from the next validation batch.
        self.losses_val = None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler_args = cfg_scheduler.get("args", {})
            # Resolve the "$total_epochs" placeholder to the trainer's epoch budget.
            for key in scheduler_args:
                if scheduler_args[key] == "$total_epochs":
                    scheduler_args[key] = int(self.trainer.max_epochs)
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **scheduler_args
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "loss/total/val",
                "strict": True,
                "name": "learning_rate",
            }
        return ret

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
        assert hparams_file is None, "hparams are not supported."
        checkpoint = torch.load(
            checkpoint_path,
            map_location=map_location or (lambda storage, loc: storage),
        )
        if find_best:
            # Pick the best checkpoint tracked by the ModelCheckpoint callbacks.
            best_score, best_name = None, None
            modes = {"min": torch.lt, "max": torch.gt}
            for key, state in checkpoint["callbacks"].items():
                if not key.startswith("ModelCheckpoint"):
                    continue
                # The callback state key embeds its kwargs as a dict literal.
                mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
                if best_score is None or modes[mode](
                    state["best_model_score"], best_score
                ):
                    best_score = state["best_model_score"]
                    best_name = Path(state["best_model_path"]).name
            logger.info("Loading best checkpoint %s", best_name)
            if best_name != checkpoint_path:
                return cls.load_from_checkpoint(
                    Path(checkpoint_path).parent / best_name,
                    map_location,
                    hparams_file,
                    strict,
                    cfg,
                    find_best=False,
                )
        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            checkpoint_path,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        if list(cfg_ckpt.keys()) == ["cfg"]:  # backward compatibility
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        # The user-provided config overrides the values stored in the checkpoint.
        with open_dict(cfg_ckpt):
            cfg = OmegaConf.merge(cfg_ckpt, cfg)
        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
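

if __name__ == "__main__":
    # Minimal, self-contained sketch of AverageKeyMeter (illustrative only;
    # the training entry point lives elsewhere in the project). It assumes
    # nothing beyond this file: non-finite loss values are dropped before
    # they reach the running mean.
    meter = AverageKeyMeter("total")
    meter.update({"total": torch.tensor([1.0, 3.0, float("nan")])})
    print(meter.compute())  # tensor(2.)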