# Source listing captured from a Hugging Face Space (status: Running on T4).
# File size: 6,174 bytes; revision 28c6826.
""" Checkpoint Saver
Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
Hacked together by / Copyright 2020 Ross Wightman
"""
import glob
import operator
import os
import logging
import torch
from .model import unwrap_model, get_state_dict
_logger = logging.getLogger(__name__)  # module logger; CheckpointSaver reports checkpoint listings and file-cleanup errors here
class CheckpointSaver:
    """Keep the top-N training checkpoints by metric plus rolling recovery checkpoints.

    ``save_checkpoint`` always refreshes ``last.pth.tar`` in ``checkpoint_dir``. The same
    file is also kept as ``<checkpoint_prefix>-<epoch>.pth.tar`` while it ranks among the
    ``max_history`` best, and as ``model_best.pth.tar`` when its metric is the best seen
    so far. ``save_recovery`` writes ``<recovery_prefix>-<epoch>-<batch_idx>.pth.tar``
    files in ``recovery_dir`` and deletes older ones. Normally at most two files written
    by this instance remain.
    """

    def __init__(
            self,
            model,
            optimizer,
            args=None,
            model_ema=None,
            amp_scaler=None,
            checkpoint_prefix='checkpoint',
            recovery_prefix='recovery',
            checkpoint_dir='',
            recovery_dir='',
            decreasing=False,
            max_history=10,
            unwrap_fn=unwrap_model):
        """
        Args:
            model: Model whose state dict is saved (via ``get_state_dict(model, unwrap_fn)``).
            optimizer: Optimizer whose ``state_dict()`` is saved.
            args: Optional training args. When given, they are saved under ``'args'``
                and ``args.model`` is recorded as ``'arch'``.
            model_ema: Optional EMA model, saved under ``'state_dict_ema'``.
            amp_scaler: Optional loss scaler. Must provide ``state_dict_key`` and
                ``state_dict()``.
            checkpoint_prefix: Filename prefix for ranked checkpoints.
            recovery_prefix: Filename prefix for recovery checkpoints.
            checkpoint_dir: Directory for the last, best and ranked checkpoints.
            recovery_dir: Directory for recovery checkpoints.
            decreasing: If True, a lower metric is better.
            max_history: Number of ranked checkpoints to keep; must be >= 1.
            unwrap_fn: Passed to ``get_state_dict``. Presumably strips wrappers such as
                DDP before saving — confirm in ``.model``.
        """
        # objects to save state_dicts of
        self.model = model
        self.optimizer = optimizer
        self.args = args
        self.model_ema = model_ema
        self.amp_scaler = amp_scaler

        # state
        self.checkpoint_files = []  # (filename, metric) tuples, best first; None metrics last
        self.best_epoch = None
        self.best_metric = None
        self.curr_recovery_file = ''
        self.last_recovery_file = ''

        # config
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.save_prefix = checkpoint_prefix
        self.recovery_prefix = recovery_prefix
        self.extension = '.pth.tar'
        self.decreasing = decreasing  # a lower metric is better if True
        self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
        self.max_history = max_history
        self.unwrap_fn = unwrap_fn
        assert self.max_history >= 1

    def save_checkpoint(self, epoch, metric=None):
        """Save a checkpoint for ``epoch`` and update the ranked and best checkpoint files.

        Args:
            epoch: Non-negative epoch index, used in the ranked checkpoint filename.
            metric: Evaluation metric for this epoch, or None. A None metric is always
                admitted to the ranked history. It ranks below every scored checkpoint
                and never becomes the best.

        Returns:
            ``(best_metric, best_epoch)`` seen so far, or ``(None, None)`` if no metric
            has been recorded yet.
        """
        assert epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        self._save(tmp_save_path, epoch, metric)
        # os.replace overwrites an existing destination on both POSIX and Windows (atomically
        # on POSIX). This avoids the unlink+rename window where 'last' did not exist.
        os.replace(tmp_save_path, last_save_path)

        worst_metric = self.checkpoint_files[-1][1] if self.checkpoint_files else None
        if (len(self.checkpoint_files) < self.max_history
                or metric is None
                or self._is_better(metric, worst_metric)):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)
            filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
            save_path = os.path.join(self.checkpoint_dir, filename)
            self._link(last_save_path, save_path)

            # Drop a stale entry for the same file (same epoch saved again), so that a
            # later cleanup cannot delete a file that another entry still references.
            self.checkpoint_files = [c for c in self.checkpoint_files if c[0] != save_path]
            self.checkpoint_files.append((save_path, metric))
            self.checkpoint_files = self._rank(self.checkpoint_files)

            _logger.info(
                "Current checkpoints:\n%s",
                ''.join(' {}\n'.format(c) for c in self.checkpoint_files))

            if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
                self.best_epoch = epoch
                self.best_metric = metric
                best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
                self._link(last_save_path, best_save_path)

        return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

    def _is_better(self, metric, other):
        """Return True if ``metric`` beats ``other``; any metric beats a missing (None) one."""
        if other is None:
            return True
        return self.cmp(metric, other)

    def _rank(self, files):
        """Return ``files`` ordered best-first.

        Entries with a None metric go last, in insertion order. Sorting them together with
        scored entries would compare None to numbers (or None to None) and raise TypeError.
        """
        scored = [f for f in files if f[1] is not None]
        unscored = [f for f in files if f[1] is None]
        # descending order unless a lower metric is better; sort is stable for ties
        scored.sort(key=lambda f: f[1], reverse=not self.decreasing)
        return scored + unscored

    @staticmethod
    def _link(src_path, dst_path):
        """Make ``dst_path`` a copy of ``src_path``, replacing any existing file.

        A hard link is used when possible. If ``os.link`` raises ``OSError`` (e.g. the
        filesystem does not support hard links), the file is copied instead.
        """
        if os.path.lexists(dst_path):
            os.unlink(dst_path)  # os.link refuses to overwrite an existing destination
        try:
            os.link(src_path, dst_path)
        except OSError:
            import shutil  # only needed on the fallback path
            shutil.copy2(src_path, dst_path)

    def _save(self, save_path, epoch, metric=None):
        """Write a checkpoint dict for the tracked objects to ``save_path`` via ``torch.save``."""
        save_state = {
            'epoch': epoch,
            'arch': type(self.model).__name__.lower(),
            'state_dict': get_state_dict(self.model, self.unwrap_fn),
            'optimizer': self.optimizer.state_dict(),
            'version': 2,  # version < 2 increments epoch before save
        }
        if self.args is not None:
            save_state['arch'] = self.args.model
            save_state['args'] = self.args
        if self.amp_scaler is not None:
            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
        if self.model_ema is not None:
            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
        if metric is not None:
            save_state['metric'] = metric
        torch.save(save_state, save_path)

    def _cleanup_checkpoints(self, trim=0):
        """Delete the worst ranked checkpoints so at most ``max_history - trim`` remain.

        Deletion is best-effort: a file that cannot be removed is logged and still dropped
        from ``checkpoint_files``.
        """
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index <= 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                _logger.debug("Cleaning checkpoint: %s", d)
                os.remove(d[0])
            except OSError as e:
                _logger.error("Exception '%s' while deleting checkpoint %s", e, d[0])
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def save_recovery(self, epoch, batch_idx=0):
        """Write a recovery checkpoint for (``epoch``, ``batch_idx``) and prune older ones.

        The file saved two calls ago is removed (best-effort). The file just written is
        never removed, even if the same epoch/batch path is saved repeatedly.
        """
        assert epoch >= 0
        filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
        save_path = os.path.join(self.recovery_dir, filename)
        self._save(save_path, epoch)
        stale = self.last_recovery_file
        if stale != save_path and os.path.exists(stale):
            try:
                _logger.debug("Cleaning recovery: %s", stale)
                os.remove(stale)
            except OSError as e:
                _logger.error("Exception '%s' while removing %s", e, stale)
        self.last_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        """Return a recovery checkpoint path from ``recovery_dir``, or '' if there is none.

        Matches ``<recovery_dir>/<recovery_prefix>*<extension>`` and returns the
        lexicographically smallest match.
        """
        recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
        files = sorted(glob.glob(recovery_path + '*' + self.extension))
        # NOTE(review): plain string ordering, so 'recovery-10-0' sorts before 'recovery-9-0'
        # and the first match is not necessarily the newest. Confirm this is the intended pick.
        return files[0] if files else ''
# (end of captured listing)