import copy
import pathlib
from typing import Any, Dict, List, Optional

from loguru import logger
import torch
from torch import nn

import virtex.utils.distributed as dist


class CheckpointManager(object):
    r"""
    A helper class to periodically serialize models and other checkpointable
    objects (optimizers, LR schedulers etc., which implement a ``state_dict``
    method) during training, and optionally record the best performing
    checkpoint based on an observed metric.

    .. note::

        For :class:`~torch.nn.parallel.DistributedDataParallel` objects, the
        ``state_dict`` of the internal model is serialized.

    .. note::

        The observed metric for keeping the best checkpoint is assumed to be
        "higher is better"; flip its sign if otherwise (for example, pass
        ``- val_loss``).

    Parameters
    ----------
    serialization_dir: str
        Path to a directory to save checkpoints.
    keep_recent: int, optional (default = 200)
        Number of recent checkpoints ``k`` to keep on disk. Older checkpoints
        will be removed. Set to a very large value to keep all checkpoints.
    checkpointables: Any
        Keyword arguments with any checkpointable objects, for example: model,
        optimizer, learning rate scheduler.

    Examples
    --------
    >>> model = torch.nn.Linear(10, 2)
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
    >>> num_epochs = 20
    >>> for epoch in range(num_epochs):
    ...     train(model)
    ...     val_loss = validate(model)
    ...     ckpt_manager.step(epoch, - val_loss)
    """

    def __init__(
        self,
        serialization_dir: str = "/tmp",
        keep_recent: int = 200,
        **checkpointables: Any,
    ):
        self.serialization_dir = pathlib.Path(serialization_dir)
        self.keep_recent = keep_recent

        # Shallow copy, keeps references to tensors as original objects.
        self.checkpointables = copy.copy(checkpointables)

        # Initialize members to hold the state dict of the best checkpoint and
        # its performance. "-inf" ensures the first observed metric always
        # becomes the best one, even when metrics are negative (e.g. ``- val_loss``).
        self._best_metric: float = -float("inf")
        self._best_ckpt: Dict[str, Any] = {}

        # Keep epoch/iteration numbers of recently saved 'k' checkpoints.
        self._recent_iterations: List[int] = []

    def step(self, iteration: int, metric: Optional[float] = None):
        r"""
        Serialize a checkpoint and update the best checkpoint based on the
        metric. Keys in the serialized checkpoint match those in
        :attr:`checkpointables`.

        Parameters
        ----------
        iteration: int
            Current training iteration. Will be saved with other checkpointables.
        metric: float, optional (default = None)
            Observed metric (higher is better) for keeping track of the best
            checkpoint. If this is ``None``, the best checkpoint will not be
            recorded/updated.
        """
        checkpointable_state_dict: Dict[str, Any] = self._state_dict()

        # We also checkpoint the current iteration.
        checkpointable_state_dict["iteration"] = iteration

        # Update the best checkpoint based on metric, if provided, and
        # serialize it immediately: the state dicts hold references to live
        # tensors, so saving now captures the weights that actually achieved
        # the best metric.
        if metric is not None and metric > self._best_metric:
            self._best_metric = metric
            self._best_ckpt = copy.copy(checkpointable_state_dict)
            torch.save(
                self._best_ckpt, self.serialization_dir / "checkpoint_best.pth"
            )

        # Serialize the checkpoint corresponding to the current iteration.
        torch.save(
            checkpointable_state_dict,
            self.serialization_dir / f"checkpoint_{iteration}.pth",
        )

        # Remove the earliest checkpoint if more than `keep_recent` are on disk.
        self._recent_iterations.append(iteration)
        if len(self._recent_iterations) > self.keep_recent:
            self.remove_earliest_checkpoint()

    def _state_dict(self):
        r"""Return a dict containing state dict of all checkpointables."""

        __state_dict: Dict[str, Any] = {}
        for key in self.checkpointables:
            if isinstance(
                self.checkpointables[key], nn.parallel.DistributedDataParallel
            ):
                __state_dict[key] = self.checkpointables[key].module.state_dict()
            else:
                __state_dict[key] = self.checkpointables[key].state_dict()

        return __state_dict

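    # Illustrative note: with ``CheckpointManager("/tmp", model=model,
    # optimizer=optimizer)`` as in the class docstring, the dict returned by
    # ``_state_dict`` above has one entry per keyword argument passed at
    # construction, roughly ``{"model": model.state_dict(), "optimizer":
    # optimizer.state_dict()}``; this is exactly what :meth:`step` serializes,
    # plus an ``"iteration"`` key.
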
    def remove_earliest_checkpoint(self):
        r"""Remove earliest serialized checkpoint from disk."""
        earliest_iteration = self._recent_iterations.pop(0)
        (self.serialization_dir / f"checkpoint_{earliest_iteration}.pth").unlink()

    def load(self, checkpoint_path: str):
        r"""
        Load a serialized checkpoint from a path. This method will try to find
        each of :attr:`checkpointables` in the file and load its state dict.
        Since our checkpointables are held as references, this method does not
        return them.

        Parameters
        ----------
        checkpoint_path: str
            Path to a checkpoint serialized by :meth:`step`.

        Returns
        -------
        int
            Iteration corresponding to the loaded checkpoint. Useful for
            resuming training. This will be -1 if the checkpoint does not
            record an iteration number.
        """
        # Each process will log a message after loading the checkpoint.
        rank = dist.get_rank()
        logger.info(f"Rank {rank}: Loading checkpoint from {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        iteration = checkpoint.pop("iteration", -1)

        # Keep flags for all checkpointables to log which ones were not loaded.
        is_loaded = {key: False for key in self.checkpointables}

        # Load each checkpointable from the checkpoint.
        for key in checkpoint:
            if key in self.checkpointables:
                logger.info(f"Rank {rank}: Loading {key} from {checkpoint_path}")

                if isinstance(
                    self.checkpointables[key], nn.parallel.DistributedDataParallel
                ):
                    self.checkpointables[key].module.load_state_dict(checkpoint[key])
                else:
                    self.checkpointables[key].load_state_dict(checkpoint[key])

                is_loaded[key] = True
            else:
                logger.info(f"Rank {rank}: {key} not found in `checkpointables`.")

        not_loaded: List[str] = [key for key in is_loaded if not is_loaded[key]]
        if len(not_loaded) > 0:
            logger.info(
                f"Rank {rank}: Checkpointables not found in file: {not_loaded}"
            )
        return iteration
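

if __name__ == "__main__":
    # A minimal, illustrative usage sketch exercising the save / resume cycle;
    # the directory, loop, and loss values below are hypothetical and only
    # meant to show the intended call pattern.
    run_dir = pathlib.Path("/tmp/ckpt_manager_demo")
    run_dir.mkdir(parents=True, exist_ok=True)

    model = nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters())
    manager = CheckpointManager(
        str(run_dir), keep_recent=3, model=model, optimizer=optimizer
    )

    for iteration in range(1, 6):
        # Pretend the validation loss improves every iteration; its negation is
        # passed as the metric because the manager assumes "higher is better".
        fake_val_loss = 1.0 / iteration
        manager.step(iteration, metric=-fake_val_loss)

    # Resume from the latest serialized checkpoint: `load` restores the state
    # dicts in place and returns the iteration number to continue from.
    resumed_iteration = manager.load(str(run_dir / "checkpoint_5.pth"))
    print(f"Resumed from iteration {resumed_iteration}")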