virtex-redcaps/virtex/utils/checkpointing.py
import copy
import pathlib
from typing import Any, Dict, List, Optional

from loguru import logger
import torch
from torch import nn

import virtex.utils.distributed as dist


class CheckpointManager(object):
    r"""
    A helper class to periodically serialize models and other checkpointable
    objects (optimizers, LR schedulers etc., which implement a ``state_dict``
    method) during training, and optionally record the best performing
    checkpoint based on an observed metric.

    .. note::

        For :class:`~torch.nn.parallel.DistributedDataParallel` objects, the
        ``state_dict`` of the internal model is serialized.

    .. note::

        The observed metric for keeping the best checkpoint is assumed to be
        "higher is better"; flip its sign (for example, negate a validation
        loss) if lower is better.

    Parameters
    ----------
    serialization_dir: str
        Path to a directory to save checkpoints.
    keep_recent: int, optional (default = 200)
        Number of recent checkpoints to keep on disk. Older checkpoints will
        be removed. Set to a very large value to keep all checkpoints.
    checkpointables: Any
        Keyword arguments with any checkpointable objects, for example: model,
        optimizer, learning rate scheduler.

    Examples
    --------
    >>> model = torch.nn.Linear(10, 2)
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
    >>> num_epochs = 20
    >>> for epoch in range(num_epochs):
    ...     train(model)
    ...     val_loss = validate(model)
    ...     ckpt_manager.step(epoch, - val_loss)
    """

    def __init__(
self,
serialization_dir: str = "/tmp",
keep_recent: int = 200,
**checkpointables: Any,
):
self.serialization_dir = pathlib.Path(serialization_dir)
self.keep_recent = keep_recent
# Shallow copy, keeps references to tensors as original objects.
self.checkpointables = copy.copy(checkpointables)
# Initialize members to hold state dict of best checkpoint and its
# performance.
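        # NOTE: ``-1e-12`` also acts as a sentinel value: ``checkpoint_best.pth``
        # is only written once :meth:`step` has observed a metric larger than it.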
self._best_metric: float = -1e-12
self._best_ckpt: Dict[str, Any] = {}
# Keep epoch/iteration numbers of recently saved 'k' checkpoints.
        self._recent_iterations: List[int] = []

    def step(self, iteration: int, metric: Optional[float] = None):
        r"""
        Serialize a checkpoint and update the best checkpoint based on metric.
        Keys in the serialized checkpoint match those in :attr:`checkpointables`.

        Parameters
        ----------
        iteration: int
            Current training iteration. Will be saved with other checkpointables.
        metric: float, optional (default = None)
            Observed metric (higher is better) for keeping track of the best
            checkpoint. If this is ``None``, the best checkpoint will not be
            recorded/updated.
        """
checkpointable_state_dict: Dict[str, Any] = self._state_dict()
# We also checkpoint current iteration.
checkpointable_state_dict["iteration"] = iteration
        # Update the best checkpoint based on metric, if provided. Deep copy so
        # the snapshot keeps the current tensor values instead of references
        # that later training steps would mutate.
        if metric is not None and metric > self._best_metric:
            self._best_metric = metric
            self._best_ckpt = copy.deepcopy(checkpointable_state_dict)
# Serialize checkpoint corresponding to current iteration.
torch.save(
checkpointable_state_dict,
self.serialization_dir / f"checkpoint_{iteration}.pth",
)
if self._best_metric != -1e-12:
# Serialize best performing checkpoint observed so far.
torch.save(
self._best_ckpt, self.serialization_dir / "checkpoint_best.pth"
)
        # Remove the earliest checkpoint if there are more than ``keep_recent``
        # checkpoints on disk.
self._recent_iterations.append(iteration)
if len(self._recent_iterations) > self.keep_recent:
            self.remove_earliest_checkpoint()

    def _state_dict(self):
r"""Return a dict containing state dict of all checkpointables."""
__state_dict: Dict[str, Any] = {}
for key in self.checkpointables:
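            # For DDP-wrapped objects, serialize the state dict of the wrapped
            # module so keys are not prefixed with ``module.``.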
if isinstance(
self.checkpointables[key], nn.parallel.DistributedDataParallel
):
__state_dict[key] = self.checkpointables[key].module.state_dict()
else:
__state_dict[key] = self.checkpointables[key].state_dict()
        return __state_dict

    def remove_earliest_checkpoint(self):
r"""Remove earliest serialized checkpoint from disk."""
earliest_iteration = self._recent_iterations.pop(0)
        (self.serialization_dir / f"checkpoint_{earliest_iteration}.pth").unlink()

    def load(self, checkpoint_path: str):
        r"""
        Load a serialized checkpoint from a path. This method will try to find
        each of :attr:`checkpointables` in the file and load its state dict.
        Since our checkpointables are held as references, this method does not
        return them.

        Parameters
        ----------
        checkpoint_path: str
            Path to a checkpoint serialized by :meth:`step`.

        Returns
        -------
        int
            Iteration corresponding to the loaded checkpoint, useful for
            resuming training. This will be ``-1`` if the iteration number is
            not stored in the checkpoint.
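
        Examples
        --------
        A minimal resume sketch; the path and checkpointables below are purely
        illustrative, not part of the original example:

        >>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
        >>> start_iteration = ckpt_manager.load("/tmp/checkpoint_1000.pth")
        >>> # Continue training from ``start_iteration + 1``.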
"""
# Each process will log a message after loading checkpoint.
rank = dist.get_rank()
logger.info(f"Rank {rank}: Loading checkpoint from {checkpoint_path}")
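        # Map all tensors to CPU; ``load_state_dict`` below copies them onto
        # whatever device each checkpointable already lives on.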
checkpoint = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint.pop("iteration", -1)
        # Keep flags for all checkpointables to log which ones were not loaded.
is_loaded = {key: False for key in self.checkpointables}
# Load each checkpointable from checkpoint.
for key in checkpoint:
if key in self.checkpointables:
logger.info(f"Rank {rank}: Loading {key} from {checkpoint_path}")
if isinstance(
self.checkpointables[key], nn.parallel.DistributedDataParallel
):
self.checkpointables[key].module.load_state_dict(checkpoint[key])
else:
self.checkpointables[key].load_state_dict(checkpoint[key])
is_loaded[key] = True
else:
logger.info(f"Rank {rank}: {key} not found in `checkpointables`.")
not_loaded: List[str] = [key for key in is_loaded if not is_loaded[key]]
if len(not_loaded) > 0:
logger.info(
f"Rank {rank}: Checkpointables not found in file: {not_loaded}"
)
return iteration
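

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): save a
    # checkpoint every epoch with a toy model, then resume from the latest one.
    # The "training" below is a stand-in using random data, purely to exercise
    # ``step`` and ``load``; it assumes a single, non-distributed process.
    import tempfile

    serialization_dir = tempfile.mkdtemp()
    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    ckpt_manager = CheckpointManager(
        serialization_dir, keep_recent=3, model=model, optimizer=optimizer
    )

    for epoch in range(5):
        # Stand-in training step on random inputs.
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()
        # Negate the loss so that "higher is better" for the best checkpoint.
        ckpt_manager.step(epoch, -loss.item())

    # Resume: load the last checkpoint back into the same model and optimizer.
    start_epoch = ckpt_manager.load(f"{serialization_dir}/checkpoint_4.pth") + 1
    logger.info(f"Resuming training from epoch {start_epoch}.")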