# Source: ECON/lib/pymafx/utils/saver.py
# Commit fb140f6: "remove MeshLab dependency with Open3D"
# (scraped page header: Yuliang's picture | raw | history blame | 5.39 kB)
from __future__ import division
import os
import torch
import datetime
import logging
logger = logging.getLogger(__name__)
class CheckpointSaver():
    """Save and restore training checkpoints stored as ``.pt`` files.

    Attributes:
        save_dir: Absolute path of the checkpoint directory (created if missing).
        save_steps: Stored as given; this class does not read it itself.
        overwrite: When true, regular saves always go to ``model_latest.pt``.
        latest_checkpoint: Path chosen by :meth:`get_latest_checkpoint`, or
            ``None`` when no ``.pt`` file was found. It is refreshed only when
            that method runs (the constructor calls it once); saving a
            checkpoint does not update it.
    """

    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        """Create the checkpoint directory if needed and scan it once.

        Args:
            save_dir: Directory that holds the checkpoint files.
            save_steps: Value kept on the instance for callers to read.
            overwrite: If true, regular saves reuse ``model_latest.pt``.
        """
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        # exist_ok avoids the check-then-create race when several processes
        # start with the same output directory.
        os.makedirs(self.save_dir, exist_ok=True)
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Check whether a checkpoint is available.

        Args:
            checkpoint_file: Optional explicit path. When omitted, the result
                reflects whether the last directory scan found any ``.pt`` file.

        Returns:
            bool: True if the given file exists, or (with no argument) if
            ``latest_checkpoint`` is set.
        """
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)

    def save_checkpoint(
        self,
        models,
        optimizers,
        epoch,
        batch_idx,
        batch_size,
        total_step_count,
        is_best=False,
        save_by_step=False,
        interval=5,
        with_optimizer=True
    ):
        """Save model (and optionally optimizer) state to disk.

        The regular target file is chosen as follows:
        ``model_latest.pt`` when ``self.overwrite`` is set; otherwise
        ``<total_step_count:08d>.pt`` when ``save_by_step`` is set; otherwise
        ``model_epoch_<epoch:02d>.pt`` but only when ``epoch`` is a multiple
        of ``interval``. If none applies, no regular file is written.
        Independently of that, ``is_best`` additionally writes ``model_best.pt``.

        Args:
            models: Mapping from name to an object exposing ``state_dict()``.
            optimizers: Mapping from name to an object exposing
                ``state_dict()``; only read when ``with_optimizer`` is true.
            epoch: Current epoch, stored in the checkpoint.
            batch_idx: Current batch index, stored in the checkpoint.
            batch_size: Batch size, stored in the checkpoint.
            total_step_count: Global step counter, stored in the checkpoint.
            is_best: Also write the checkpoint to ``model_best.pt``.
            save_by_step: Name the regular file after ``total_step_count``.
            interval: Epoch interval for epoch-named checkpoints.
            with_optimizer: Include optimizer state dicts in the checkpoint.
        """
        timestamp = datetime.datetime.now()

        if self.overwrite:
            checkpoint_filename = os.path.abspath(
                os.path.join(self.save_dir, 'model_latest.pt')
            )
        elif save_by_step:
            checkpoint_filename = os.path.abspath(
                os.path.join(self.save_dir, '{:08d}.pt'.format(total_step_count))
            )
        elif epoch % interval == 0:
            checkpoint_filename = os.path.abspath(
                os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt')
            )
        else:
            checkpoint_filename = None

        checkpoint = {}
        for name in models:
            model_dict = models[name].state_dict()
            # Entries whose key contains '.smpl.' are left out of the saved
            # state; keys are deleted in place so the mapping object returned
            # by state_dict() is otherwise stored unchanged.
            for key in list(model_dict.keys()):
                if '.smpl.' in key:
                    del model_dict[key]
            checkpoint[name] = model_dict
        if with_optimizer:
            for name in optimizers:
                checkpoint[name] = optimizers[name].state_dict()
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count
        print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)

        # Collect every destination first and write each exactly once: the
        # best checkpoint used to be written twice, and a missing regular
        # filename must never reach torch.save.
        targets = []
        if checkpoint_filename is not None:
            targets.append(checkpoint_filename)
        if is_best:
            targets.append(os.path.abspath(os.path.join(self.save_dir, 'model_best.pt')))

        for path in targets:
            print('Saving checkpoint file [' + path + ']')
            torch.save(checkpoint, path)
            print('Saved checkpoint file [' + path + ']')

    def load_checkpoint(self, models, optimizers, checkpoint_file=None):
        """Load a checkpoint into the given models and optimizers.

        For each model, only checkpoint entries whose key also exists in the
        model's current state dict are applied; all other model entries keep
        their current values.

        Args:
            models: Mapping from name to an object exposing ``state_dict()``
                and ``load_state_dict()``.
            optimizers: Mapping from name to an object exposing
                ``load_state_dict()``; ``None`` is treated as empty.
            checkpoint_file: Path to load. Defaults to ``latest_checkpoint``.

        Returns:
            dict: The stored ``epoch``, ``batch_idx``, ``batch_size`` and
            ``total_step_count`` values.

        Raises:
            FileNotFoundError: If no file is given and no checkpoint was found
                by the last directory scan.
        """
        if checkpoint_file is None:
            if self.latest_checkpoint is None:
                raise FileNotFoundError(
                    'No checkpoint file found in [' + self.save_dir + ']'
                )
            logger.info('Loading latest checkpoint [' + self.latest_checkpoint + ']')
            checkpoint_file = self.latest_checkpoint

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_file)

        for name in models:
            if name not in checkpoint:
                continue
            model_dict = models[name].state_dict()
            # Start from the model's own state and overlay the matching saved
            # entries, so keys missing from the file do not cause a failure.
            model_dict.update(
                {k: v for k, v in checkpoint[name].items() if k in model_dict}
            )
            models[name].load_state_dict(model_dict)

        for name in (optimizers or {}):
            if name in checkpoint:
                optimizers[name].load_state_dict(checkpoint[name])

        return {
            'epoch': checkpoint['epoch'],
            'batch_idx': checkpoint['batch_idx'],
            'batch_size': checkpoint['batch_size'],
            'total_step_count': checkpoint['total_step_count']
        }

    def get_latest_checkpoint(self):
        """Scan ``save_dir`` recursively and record the last ``.pt`` file.

        Files are ordered by a natural ("human") sort of their absolute path,
        in which embedded numbers compare numerically; the last one in that
        order is stored in ``self.latest_checkpoint`` (``None`` if no ``.pt``
        file exists). The choice is based on names only, not on file times.
        """
        import re

        # Float regex from https://stackoverflow.com/a/12643073/190597; the
        # single capturing group makes re.split alternate text / number.
        number_pattern = re.compile(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)')

        def natural_keys(text):
            """Return a sort key that compares embedded numbers numerically.

            See http://nedbatchelder.com/blog/200712/human_sorting.html
            """
            pieces = number_pattern.split(text)
            # Odd positions hold the captured numbers, even positions the text
            # in between. Converting only the numbers keeps each position's
            # type consistent, so text such as 'inf' is never turned into a
            # float and compared against a string.
            return [float(p) if i % 2 else p for i, p in enumerate(pieces)]

        checkpoint_list = []
        for dirpath, _dirnames, filenames in os.walk(self.save_dir):
            for filename in filenames:
                if filename.endswith('.pt'):
                    checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename)))

        checkpoint_list.sort(key=natural_keys)
        self.latest_checkpoint = checkpoint_list[-1] if checkpoint_list else None