File size: 5,389 Bytes
2252f3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from __future__ import division

import datetime
import logging
import os
import re

import torch
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class CheckpointSaver:
    """Save and restore training checkpoints kept in a single directory.

    A checkpoint is a ``torch.save``-d dict holding one ``state_dict`` per
    model and (optionally) per optimizer, keyed by the names used in the
    ``models`` / ``optimizers`` mappings, plus the bookkeeping keys
    ``epoch``, ``batch_idx``, ``batch_size`` and ``total_step_count``.
    """

    # Splits text into alternating non-number / number segments for natural
    # ("human") ordering, so e.g. ``a10.pt`` sorts after ``a9.pt``. The
    # optional sign sits outside the capture group and is therefore dropped.
    # Float regex from https://stackoverflow.com/a/12643073/190597; approach from
    # http://nedbatchelder.com/blog/200712/human_sorting.html
    _NUMBER_RE = re.compile(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)')

    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        """
        Args:
            save_dir: directory checkpoints are written to and searched in;
                created (with parents) if missing.
            save_steps: stored as ``self.save_steps``; not read anywhere in
                this class (presumably consumed by callers).
            overwrite: if True, every regular save goes to ``model_latest.pt``.
        """
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(self.save_dir, exist_ok=True)
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Return whether a checkpoint is available.

        With no argument, reports whether ``self.latest_checkpoint`` is set.
        That value is computed at construction (and by explicit calls to
        ``get_latest_checkpoint``); it is not refreshed by ``save_checkpoint``.
        Otherwise, reports whether ``checkpoint_file`` is an existing file.
        """
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)

    def save_checkpoint(
        self,
        models,
        optimizers,
        epoch,
        batch_idx,
        batch_size,
        total_step_count,
        is_best=False,
        save_by_step=False,
        interval=5,
        with_optimizer=True
    ):
        """Write the current training state to ``save_dir``.

        The regular target file is:
          * ``model_latest.pt`` if ``self.overwrite``;
          * else ``{total_step_count:08d}.pt`` if ``save_by_step``;
          * else ``model_epoch_{epoch:02d}.pt`` only when ``epoch`` is a
            multiple of ``interval`` (nothing regular is written otherwise).
        Independently, ``is_best`` also writes ``model_best.pt``. Each target
        is written exactly once.

        Args:
            models: mapping name -> ``nn.Module``; each ``state_dict`` is
                stored under its name, minus every key containing ``'.smpl.'``.
            optimizers: mapping name -> optimizer; stored only if
                ``with_optimizer``.
            epoch, batch_idx, batch_size, total_step_count: bookkeeping values
                stored verbatim (and returned by ``load_checkpoint``).
            is_best: additionally write ``model_best.pt``.
            save_by_step: name the regular file after ``total_step_count``.
            interval: epoch period for ``model_epoch_XX.pt`` files.
            with_optimizer: include optimizer states in the checkpoint.
        """
        timestamp = datetime.datetime.now()

        targets = []
        if self.overwrite:
            targets.append('model_latest.pt')
        elif save_by_step:
            targets.append('{:08d}.pt'.format(total_step_count))
        elif epoch % interval == 0:
            targets.append(f'model_epoch_{epoch:02d}.pt')
        if is_best:  # save the best
            targets.append('model_best.pt')

        print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)
        if not targets:
            # Off-interval epoch and not best: nothing to write.
            return

        checkpoint = {}
        for name in models:
            model_dict = models[name].state_dict()
            # NOTE(review): '.smpl.' entries are excluded from checkpoints,
            # presumably because they are rebuilt when the model is
            # constructed - confirm. Delete in place (instead of rebuilding the
            # dict) so the state_dict OrderedDict keeps its ``_metadata``.
            for key in [k for k in model_dict if '.smpl.' in k]:
                del model_dict[key]
            checkpoint[name] = model_dict
        if with_optimizer:
            for name in optimizers:
                checkpoint[name] = optimizers[name].state_dict()
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count

        for filename in targets:
            # save_dir is already absolute, so the joined path is too.
            checkpoint_filename = os.path.join(self.save_dir, filename)
            torch.save(checkpoint, checkpoint_filename)
            print('Saved checkpoint file [' + checkpoint_filename + ']')

    def load_checkpoint(self, models, optimizers, checkpoint_file=None,
                        map_location=None):
        """Restore model and optimizer state from a checkpoint file.

        Models are restored leniently: only keys present both in the
        checkpoint and in the model's current ``state_dict`` are overwritten.
        All other entries (e.g. the ``'.smpl.'`` keys dropped on save) keep
        their current values. Models or optimizers whose name is absent from
        the checkpoint are left untouched.

        Args:
            models: mapping name -> ``nn.Module`` to restore in place.
            optimizers: mapping name -> optimizer to restore in place.
            checkpoint_file: path to load; defaults to
                ``self.latest_checkpoint``.
            map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` to load
                a GPU checkpoint on a CPU-only machine); ``None`` keeps
                ``torch.load``'s default.

        Returns:
            dict with the stored ``epoch``, ``batch_idx``, ``batch_size`` and
            ``total_step_count``.

        Raises:
            FileNotFoundError: if no ``checkpoint_file`` is given and no
                checkpoint was found in ``save_dir``.
        """
        if checkpoint_file is None:
            if self.latest_checkpoint is None:
                raise FileNotFoundError(
                    'No checkpoint file given and none found in ['
                    + self.save_dir + ']'
                )
            logger.info('Loading latest checkpoint [%s]', self.latest_checkpoint)
            checkpoint_file = self.latest_checkpoint
        checkpoint = torch.load(checkpoint_file, map_location=map_location)

        for name in models:
            if name not in checkpoint:
                continue
            model_dict = models[name].state_dict()
            model_dict.update(
                {k: v for k, v in checkpoint[name].items() if k in model_dict}
            )
            models[name].load_state_dict(model_dict)
        for name in optimizers:
            if name in checkpoint:
                optimizers[name].load_state_dict(checkpoint[name])

        return {
            'epoch': checkpoint['epoch'],
            'batch_idx': checkpoint['batch_idx'],
            'batch_size': checkpoint['batch_size'],
            'total_step_count': checkpoint['total_step_count']
        }

    @classmethod
    def _natural_key(cls, text):
        """Sort key that orders embedded numbers numerically ("9" < "10").

        ``re.split`` with one capture group yields non-number text at even
        indices and the captured numbers at odd indices. Only the odd ones are
        converted to float, so keys never compare ``str`` against ``float`` at
        the same position, and text such as ``'nan'`` or ``'inf'`` stays text.
        """
        return [
            float(part) if index % 2 else part
            for index, part in enumerate(cls._NUMBER_RE.split(text))
        ]

    def get_latest_checkpoint(self):
        """Set ``self.latest_checkpoint`` from the files currently on disk.

        Recursively collects every ``*.pt`` file under ``save_dir`` and keeps
        the last one in natural order of its absolute path, or ``None`` if
        there is none. The choice is by name, not modification time. For
        example, ``model_latest.pt`` outranks ``model_epoch_XX.pt``.
        NOTE(review): ``model_best.pt`` also outranks step-named
        ``XXXXXXXX.pt`` files in the same directory.
        """
        checkpoint_list = [
            os.path.abspath(os.path.join(dirpath, filename))
            for dirpath, _, filenames in os.walk(self.save_dir)
            for filename in filenames
            if filename.endswith('.pt')
        ]
        # Stable sort + last element: among equal keys the last one walked wins.
        checkpoint_list.sort(key=self._natural_key)
        self.latest_checkpoint = checkpoint_list[-1] if checkpoint_list else None
|