from __future__ import division
import os
import re
import torch
import datetime
import logging

logger = logging.getLogger(__name__)


class CheckpointSaver:
    """Class that handles saving and loading checkpoints during training."""
    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        # Scan save_dir for existing checkpoints so training can resume from the latest one.
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Check whether a checkpoint exists in the save directory."""
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)

    def save_checkpoint(
        self,
        models,
        optimizers,
        epoch,
        batch_idx,
        batch_size,
        total_step_count,
        is_best=False,
        save_by_step=False,
        interval=5,
        with_optimizer=True
    ):
        """Save checkpoint."""
        timestamp = datetime.datetime.now()
        if self.overwrite:
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_latest.pt'))
        elif save_by_step:
            checkpoint_filename = os.path.abspath(
                os.path.join(self.save_dir, '{:08d}.pt'.format(total_step_count))
            )
        else:
            if epoch % interval == 0:
                checkpoint_filename = os.path.abspath(
                    os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt')
                )
            else:
                checkpoint_filename = None

        checkpoint = {}
        for model in models:
            model_dict = models[model].state_dict()
            # Drop keys belonging to an embedded SMPL sub-module so they are not serialized.
            for k in list(model_dict.keys()):
                if '.smpl.' in k:
                    del model_dict[k]
            checkpoint[model] = model_dict
        if with_optimizer:
            for optimizer in optimizers:
                checkpoint[optimizer] = optimizers[optimizer].state_dict()
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count

        if checkpoint_filename is not None:
            torch.save(checkpoint, checkpoint_filename)
            print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)
            print('Saved checkpoint file [' + checkpoint_filename + ']')
        if is_best:    # additionally save a copy as the best model so far
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt'))
            torch.save(checkpoint, checkpoint_filename)
            print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)
            print('Saved checkpoint file [' + checkpoint_filename + ']')

    def load_checkpoint(self, models, optimizers, checkpoint_file=None):
        """Load a checkpoint."""
        if checkpoint_file is None:
            logger.info('Loading latest checkpoint [' + self.latest_checkpoint + ']')
            checkpoint_file = self.latest_checkpoint
        checkpoint = torch.load(checkpoint_file)
        for model in models:
            if model in checkpoint:
                model_dict = models[model].state_dict()
                # Only restore parameters that exist in the current model
                # (e.g. SMPL keys stripped at save time are skipped).
                pretrained_dict = {
                    k: v
                    for k, v in checkpoint[model].items() if k in model_dict.keys()
                }
                model_dict.update(pretrained_dict)
                models[model].load_state_dict(model_dict)

        for optimizer in optimizers:
            if optimizer in checkpoint:
                optimizers[optimizer].load_state_dict(checkpoint[optimizer])
        return {
            'epoch': checkpoint['epoch'],
            'batch_idx': checkpoint['batch_idx'],
            'batch_size': checkpoint['batch_size'],
            'total_step_count': checkpoint['total_step_count']
        }

    def get_latest_checkpoint(self):
        """Get filename of latest checkpoint if it exists."""
        checkpoint_list = []
        for dirpath, dirnames, filenames in os.walk(self.save_dir):
            for filename in filenames:
                if filename.endswith('.pt'):
                    checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename)))
        # Sort checkpoint filenames in natural (human) order.

        def atof(text):
            try:
                retval = float(text)
            except ValueError:
                retval = text
            return retval

        def natural_keys(text):
            '''
            alist.sort(key=natural_keys) sorts in human order
            http://nedbatchelder.com/blog/200712/human_sorting.html
            (See Toothy's implementation in the comments)
            float regex comes from https://stackoverflow.com/a/12643073/190597
            '''
            return [atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)]

        checkpoint_list.sort(key=natural_keys)
        self.latest_checkpoint = checkpoint_list[-1] if checkpoint_list else None
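

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original module):
    # shows how CheckpointSaver could drive resume-and-save in a toy training
    # setup. The model, optimizer, and './checkpoints' path below are
    # hypothetical placeholders.
    import torch.nn as nn

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    saver = CheckpointSaver(save_dir='./checkpoints', overwrite=True)

    models = {'model': model}
    optimizers = {'optimizer': optimizer}

    # Resume from the latest checkpoint if one exists.
    if saver.exists_checkpoint():
        state = saver.load_checkpoint(models, optimizers)
        start_epoch = state['epoch'] + 1
    else:
        start_epoch = 0

    # Save a checkpoint at the end of an (imaginary) epoch.
    saver.save_checkpoint(
        models,
        optimizers,
        epoch=start_epoch,
        batch_idx=0,
        batch_size=32,
        total_step_count=0,
        is_best=False,
        interval=1,
    )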