import os
import numpy as np
import torch
import time
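
# Illustrative sketch (hypothetical helper, not called by IterationCounter):
# it restates how __init__ turns an explicit --resume_iter value into an
# image count. A plain number is used as-is, and a "k" suffix multiplies it
# by 1000, so "50k" means 50,000 images. It assumes the value is either
# "latest" or digits with an optional "k", as the checks in __init__ imply.
def parse_resume_iter(resume_iter):
    """Return the image count encoded by resume_iter, or None for "latest"."""
    digits = resume_iter.replace("k", "")
    if not digits.isnumeric():
        # e.g. "latest": the count is read from iter.txt instead.
        return None
    steps = int(digits)
    if "k" in resume_iter:
        steps *= 1000
    return steps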


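class IterationCounter:
    """Tracks training progress in images seen, not optimizer steps.

    All frequencies below are measured in images, and the counter advances
    by the effective batch size after every iteration.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        # Total number of training images to process (25M by default).
        parser.add_argument("--total_nimgs", default=25 * (1000 ** 2),
                            type=int)
        # How often (in images) to checkpoint, evaluate, print and display.
        parser.add_argument("--save_freq", default=50000, type=int)
        parser.add_argument("--evaluation_freq", default=50000, type=int)
        parser.add_argument("--print_freq", default=480, type=int)
        parser.add_argument("--display_freq", default=1600, type=int)
        return parser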

    def __init__(self, opt):
        self.opt = opt
        # The iteration count is persisted next to the checkpoints.
        self.iter_record_path = os.path.join(
            self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
        self.steps_so_far = 0
        # For unaligned datasets each iteration is counted as twice the
        # nominal batch size (presumably one image from each domain).
        if "unaligned" in opt.dataset_mode:
            self.batch_size = opt.batch_size * 2
        else:
            self.batch_size = opt.batch_size
        # Smoothed per-image timings, keyed by measurement name.
        self.time_measurements = {}

        automatically_find_resume_iter = opt.isTrain and opt.continue_train \
            and opt.resume_iter == "latest" and opt.pretrained_name is None
        resume_at_specified_iter = opt.isTrain and opt.continue_train \
            and opt.resume_iter.replace("k", "").isnumeric()
        if automatically_find_resume_iter:
            # Resume from the count last written by record_one_iteration().
            try:
                # np.loadtxt returns a 0-d array; store a plain int instead.
                self.steps_so_far = int(np.loadtxt(
                    self.iter_record_path, delimiter=',', dtype=int))
                print('Resuming from iteration %d' % self.steps_so_far)
            except Exception:
                print('Could not load iteration record at %s. '
                      'Starting from beginning.' % self.iter_record_path)
        elif resume_at_specified_iter:
            # An explicit count such as "50000" or "50k" (= 50,000 images).
            steps = int(opt.resume_iter.replace("k", ""))
            if "k" in opt.resume_iter:
                steps *= 1000
            self.steps_so_far = steps
        else:
            self.steps_so_far = 0

    def record_one_iteration(self):
        """Advance the counter by one batch, saving the count first if due."""
        if self.needs_saving():
            np.savetxt(self.iter_record_path,
                       [self.steps_so_far], delimiter=',', fmt='%d')
            print("Saved current iter count at %s" % self.iter_record_path)
        self.steps_so_far += self.batch_size

    # Each needs_* check is true for the one batch whose starting count lies
    # within batch_size images after a multiple of the given frequency.
    def needs_saving(self):
        return (self.steps_so_far % self.opt.save_freq) < self.batch_size

    def needs_evaluation(self):
        # Unlike the other checks, this never fires before the first
        # evaluation_freq images have been processed.
        return (self.steps_so_far >= self.opt.evaluation_freq) and \
            ((self.steps_so_far % self.opt.evaluation_freq) < self.batch_size)

    def needs_printing(self):
        return (self.steps_so_far % self.opt.print_freq) < self.batch_size

    def needs_displaying(self):
        return (self.steps_so_far % self.opt.display_freq) < self.batch_size

    def completed_training(self):
        return self.steps_so_far >= self.opt.total_nimgs

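    class TimeMeasurement:
        """Context manager that records the per-image wall time of a block."""

        def __init__(self, name, parent):
            self.name = name
            self.parent = parent

        def __enter__(self):
            # Wait for already queued GPU work so it is not billed here.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            self.start_time = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # CUDA kernels run asynchronously, so wait for them to finish
            # before reading the clock. Skipped on CPU-only machines, where
            # torch.cuda.synchronize() would raise.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed_time = (time.perf_counter() - self.start_time) \
                / self.parent.batch_size

            # Exponential moving average: 98% history, 2% newest sample.
            measurements = self.parent.time_measurements
            if self.name not in measurements:
                measurements[self.name] = elapsed_time
            else:
                measurements[self.name] = \
                    measurements[self.name] * 0.98 + elapsed_time * 0.02

    def time_measurement(self, name):
        """Usage: `with counter.time_measurement("train"): ...`"""
        return IterationCounter.TimeMeasurement(name, self)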
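

# Illustrative sketch (hypothetical helpers, not called by IterationCounter):
# they restate two mechanisms of the class in isolation. The needs_* methods
# test `steps_so_far % freq < batch_size`, which holds for exactly one batch
# per period when the count advances in fixed batch_size increments and
# freq >= batch_size. TimeMeasurement smooths per-image timings with an
# exponential moving average that gives weight 0.02 to the newest sample.
def firing_steps(freq, batch_size, total):
    """Image counts (taken before each batch) at which a needs_* check fires."""
    return [step for step in range(0, total, batch_size)
            if step % freq < batch_size]


def smoothed_time(samples, new_weight=0.02):
    """Running average seeded with the first sample, as in time_measurements."""
    average = None
    for sample in samples:
        if average is None:
            average = sample
        else:
            average = average * (1 - new_weight) + sample * new_weight
    return average

# With print_freq=480 and batch_size=50 the check fires at 0, 500 and 1000:
# the first batch starting in each 480-image window, not exact multiples.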