Spaces:
Runtime error
Runtime error
File size: 3,682 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import numpy as np
import torch
import time
class IterationCounter():
    """Track training progress (in images) and schedule periodic actions.

    Progress is counted in images rather than optimizer steps: each call to
    ``record_one_iteration`` advances ``steps_so_far`` by ``batch_size``.
    For ``"unaligned"`` dataset modes that is twice ``opt.batch_size``.
    The ``needs_*`` predicates return True on the first iteration whose count
    reaches a multiple of the corresponding ``opt.*_freq`` option.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the scheduling options on ``parser`` and return it.

        ``is_train`` is accepted for interface compatibility and is unused.
        """
        parser.add_argument("--total_nimgs", default=25 *
                            (1000 ** 2), type=int)
        parser.add_argument("--save_freq", default=50000, type=int)
        parser.add_argument("--evaluation_freq", default=50000, type=int)
        parser.add_argument("--print_freq", default=480, type=int)
        parser.add_argument("--display_freq", default=1600, type=int)
        return parser

    def __init__(self, opt):
        """Create a counter, optionally resuming the count of a previous run.

        Args:
            opt: options namespace. Reads ``checkpoints_dir``, ``name``,
                ``dataset_mode``, ``batch_size``, ``isTrain``,
                ``continue_train``, ``resume_iter`` and ``pretrained_name``
                here; ``save_freq``, ``evaluation_freq``, ``print_freq``,
                ``display_freq`` and ``total_nimgs`` are read later by the
                scheduling predicates.

        Resuming only happens when both ``isTrain`` and ``continue_train``
        are set:

        * ``resume_iter == "latest"`` with no ``pretrained_name``: the count
          is read from ``<checkpoints_dir>/<name>/iter.txt``. A missing or
          malformed record falls back to 0.
        * A decimal ``resume_iter`` such as ``"250"`` or ``"10k"``: the count
          is set directly, where ``k`` multiplies by 1000.
        """
        self.opt = opt
        self.iter_record_path = os.path.join(
            self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
        self.steps_so_far = 0
        # "unaligned" modes count each iteration as twice the batch size
        # (presumably one image per domain -- confirm against the dataset).
        if "unaligned" in opt.dataset_mode:
            self.batch_size = opt.batch_size * 2
        else:
            self.batch_size = opt.batch_size
        # Maps section name -> smoothed per-image seconds (see TimeMeasurement).
        self.time_measurements = {}

        resuming = opt.isTrain and opt.continue_train
        automatically_find_resume_iter = resuming \
            and opt.resume_iter == "latest" and opt.pretrained_name is None
        # isdecimal(), not isnumeric(): isnumeric() also accepts characters
        # such as "½" that int() cannot parse.
        resume_at_specified_iter = resuming \
            and opt.resume_iter.replace("k", "").isdecimal()
        if automatically_find_resume_iter:
            self.steps_so_far = self._load_iter_record()
        elif resume_at_specified_iter:
            steps = int(opt.resume_iter.replace("k", ""))
            if "k" in opt.resume_iter:
                steps *= 1000
            self.steps_so_far = steps

    def _load_iter_record(self):
        """Return the count stored at ``iter_record_path``, or 0 on failure."""
        try:
            # int() also rejects records holding more than one value, so a
            # bad record can never leave a numpy array in steps_so_far.
            steps = int(np.loadtxt(
                self.iter_record_path, delimiter=',', dtype=int))
        except (OSError, ValueError, TypeError):
            print('Could not load iteration record at %s. '
                  'Starting from beginning.' % self.iter_record_path)
            return 0
        print('Resuming from iteration %d' % (steps))
        return steps

    def record_one_iteration(self):
        """Advance the count by one batch, first persisting it if a save is due.

        The value written to ``iter_record_path`` is the count *before* this
        iteration's images are added.
        """
        if self.needs_saving():
            np.savetxt(self.iter_record_path,
                       [self.steps_so_far], delimiter=',', fmt='%d')
            print("Saved current iter count at %s" % self.iter_record_path)
        self.steps_so_far += self.batch_size

    def _is_due(self, freq):
        """True when the current batch is the first at/after a multiple of freq."""
        return (self.steps_so_far % freq) < self.batch_size

    def needs_saving(self):
        """Whether a checkpoint is due (also True at count 0)."""
        return self._is_due(self.opt.save_freq)

    def needs_evaluation(self):
        """Whether evaluation is due; never before ``evaluation_freq`` images."""
        return (self.steps_so_far >= self.opt.evaluation_freq) and \
            self._is_due(self.opt.evaluation_freq)

    def needs_printing(self):
        """Whether a progress print is due."""
        return self._is_due(self.opt.print_freq)

    def needs_displaying(self):
        """Whether a visual display is due."""
        return self._is_due(self.opt.display_freq)

    def completed_training(self):
        """Whether at least ``opt.total_nimgs`` images have been processed."""
        return (self.steps_so_far >= self.opt.total_nimgs)

    class TimeMeasurement:
        """Context manager recording a smoothed per-image wall-clock time.

        On exit, the elapsed time divided by ``parent.batch_size`` is stored
        in ``parent.time_measurements[name]``. The first sample is stored
        as-is; later samples are blended in as an exponential moving average
        with weight 0.02.
        """

        def __init__(self, name, parent):
            self.name = name
            self.parent = parent

        @staticmethod
        def _synchronize():
            # CUDA work is queued asynchronously; wait for it so that it is
            # attributed to the right section. On CPU-only setups
            # torch.cuda.synchronize() raises, so skip it there.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        def __enter__(self):
            # Flush work queued before this section so it is not counted here.
            self._synchronize()
            self.start_time = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self._synchronize()
            elapsed_time = (time.perf_counter() - self.start_time) \
                / self.parent.batch_size
            measurements = self.parent.time_measurements
            if self.name not in measurements:
                measurements[self.name] = elapsed_time
            else:
                prev_time = measurements[self.name]
                measurements[self.name] = prev_time * 0.98 + elapsed_time * 0.02
            # Returning None: exceptions raised inside the block propagate.

    def time_measurement(self, name):
        """Return a context manager that times a section under ``name``."""
        return IterationCounter.TimeMeasurement(name, self)
|