Spaces:
Runtime error
Runtime error
File size: 6,810 Bytes
a5f8a35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import copy
import pathlib
from typing import Any, Dict, List, Optional
from loguru import logger
import torch
from torch import nn
import virtex.utils.distributed as dist
class CheckpointManager:
    r"""
    A helper class to periodically serialize models and other checkpointable
    objects (optimizers, LR schedulers etc., which implement ``state_dict``
    method) during training, and optionally record best performing checkpoint
    based on an observed metric.

    .. note::

        For :class:`~torch.nn.parallel.DistributedDataParallel` objects,
        ``state_dict`` of internal model is serialized.

    .. note::

        The observed metric for keeping best checkpoint is assumed "higher is
        better", flip the sign if otherwise.

    Parameters
    ----------
    serialization_dir: str
        Path to a directory to save checkpoints.
    keep_recent: int, optional (default = 200)
        Number of recent ``k`` checkpoints to keep on disk. Older checkpoints
        will be removed. Set to a very large value for keeping all checkpoints.
    checkpointables: Any
        Keyword arguments with any checkpointable objects, for example: model,
        optimizer, learning rate scheduler.

    Examples
    --------
    >>> model = torch.nn.Linear(10, 2)
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
    >>> num_epochs = 20
    >>> for epoch in range(num_epochs):
    ...     train(model)
    ...     val_loss = validate(model)
    ...     ckpt_manager.step(epoch, - val_loss)
    """

    def __init__(
        self,
        serialization_dir: str = "/tmp",
        keep_recent: int = 200,
        **checkpointables: Any,
    ):
        self.serialization_dir = pathlib.Path(serialization_dir)
        self.keep_recent = keep_recent

        # Shallow copy, keeps references to tensors as original objects.
        self.checkpointables = copy.copy(checkpointables)

        # Best metric observed so far. Negative infinity guarantees that the
        # first metric passed to ``step`` (even a negative one, such as a
        # negated loss) is recorded as the best.
        self._best_metric: float = float("-inf")

        # Keep epoch/iteration numbers of recently saved 'k' checkpoints.
        self._recent_iterations: List[int] = []

    def step(self, iteration: int, metric: Optional[float] = None):
        r"""
        Serialize checkpoint and update best checkpoint based on metric. Keys
        in serialized checkpoint match those in :attr:`checkpointables`.

        Parameters
        ----------
        iteration: int
            Current training iteration. Will be saved with other checkpointables.
        metric: float, optional (default = None)
            Observed metric (higher is better) for keeping track of best
            checkpoint. If this is ``None``, best checkpoint will not be
            recorded/updated.
        """
        # Make sure the target directory exists before writing into it.
        self.serialization_dir.mkdir(parents=True, exist_ok=True)

        checkpointable_state_dict: Dict[str, Any] = self._state_dict()

        # We also checkpoint current iteration.
        checkpointable_state_dict["iteration"] = iteration

        # Serialize checkpoint corresponding to current iteration.
        torch.save(
            checkpointable_state_dict,
            self.serialization_dir / f"checkpoint_{iteration}.pth",
        )

        # Update the best checkpoint based on metric, if provided. The file is
        # written right away, while the state dict still describes the best
        # weights: state dicts hold references to the live tensors, so a copy
        # kept in memory would be silently overwritten by later training steps.
        if metric is not None and metric > self._best_metric:
            self._best_metric = metric
            torch.save(
                checkpointable_state_dict,
                self.serialization_dir / "checkpoint_best.pth",
            )

        # Remove earliest checkpoint if there are more on disk.
        self._recent_iterations.append(iteration)
        if len(self._recent_iterations) > self.keep_recent:
            self.remove_earliest_checkpoint()

    @staticmethod
    def _unwrap(checkpointable: Any) -> Any:
        r"""Return the internal module of a DistributedDataParallel wrapper,
        and any other object unchanged."""
        if isinstance(checkpointable, nn.parallel.DistributedDataParallel):
            return checkpointable.module
        return checkpointable

    def _state_dict(self) -> Dict[str, Any]:
        r"""Return a dict containing state dict of all checkpointables."""
        return {
            key: self._unwrap(checkpointable).state_dict()
            for key, checkpointable in self.checkpointables.items()
        }

    def remove_earliest_checkpoint(self):
        r"""
        Remove earliest serialized checkpoint from disk. Does nothing if no
        checkpoint is being tracked, or if the file is already gone.
        """
        if not self._recent_iterations:
            return

        earliest_iteration = self._recent_iterations.pop(0)
        earliest_path = (
            self.serialization_dir / f"checkpoint_{earliest_iteration}.pth"
        )
        try:
            earliest_path.unlink()
        except FileNotFoundError:
            # Already removed (for example the same iteration was saved
            # twice, or the file was cleaned up externally): nothing to do.
            pass

    def load(self, checkpoint_path: str) -> int:
        r"""
        Load a serialized checkpoint from a path. This method will try to find
        each of :attr:`checkpointables` in the file and load its state dict.
        Since our checkpointables are held as references, this method does not
        return them.

        Parameters
        ----------
        checkpoint_path: str
            Path to a checkpoint serialized by :meth:`step`.

        Returns
        -------
        int
            Iteration corresponding to the loaded checkpoint. Useful for
            resuming training. This will be -1 if the file does not contain
            an ``"iteration"`` entry.
        """
        # Each process will log a message after loading checkpoint.
        rank = dist.get_rank()
        logger.info(f"Rank {rank}: Loading checkpoint from {checkpoint_path}")

        # NOTE(security): ``torch.load`` unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        iteration = checkpoint.pop("iteration", -1)

        # Keep flags of all checkpointables to log which ones were not loaded.
        is_loaded = {key: False for key in self.checkpointables}

        # Load each checkpointable from checkpoint.
        for key in checkpoint:
            if key in self.checkpointables:
                logger.info(f"Rank {rank}: Loading {key} from {checkpoint_path}")
                self._unwrap(self.checkpointables[key]).load_state_dict(
                    checkpoint[key]
                )
                is_loaded[key] = True
            else:
                logger.info(f"Rank {rank}: {key} not found in `checkpointables`.")

        not_loaded: List[str] = [key for key in is_loaded if not is_loaded[key]]
        if len(not_loaded) > 0:
            logger.info(
                f"Rank {rank}: Checkpointables not found in file: {not_loaded}"
            )
        return iteration
|