# TextureScraping / swapae / util / iter_counter.py
# sunshineatnoon - commit 1b2a9b1: "Add application file"
import os
import numpy as np
import torch
import time
class IterationCounter:
    """Track training progress and decide when periodic actions are due.

    Progress is counted in images: ``steps_so_far`` grows by ``batch_size``
    on every call to :meth:`record_one_iteration`. The ``*_freq`` options and
    ``total_nimgs`` are compared against that count by the ``needs_*`` and
    :meth:`completed_training` methods.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this counter's command-line options on ``parser``.

        As used by this class, every value is a number of images, because
        the counter advances by the effective batch size per iteration.
        ``is_train`` is accepted for interface compatibility and is unused.

        Returns:
            The same ``parser``, for chaining.
        """
        parser.add_argument("--total_nimgs", default=25 * (1000 ** 2),
                            type=int)
        parser.add_argument("--save_freq", default=50000, type=int)
        parser.add_argument("--evaluation_freq", default=50000, type=int)
        parser.add_argument("--print_freq", default=480, type=int)
        parser.add_argument("--display_freq", default=1600, type=int)
        return parser

    def __init__(self, opt):
        """Create a counter, optionally resuming from a previous run.

        Args:
            opt: Options namespace. This method reads ``checkpoints_dir``,
                ``name``, ``dataset_mode``, ``batch_size``, ``isTrain``,
                ``continue_train``, ``resume_iter`` and ``pretrained_name``.
                Other methods later read ``save_freq``, ``evaluation_freq``,
                ``print_freq``, ``display_freq`` and ``total_nimgs``.

        The starting count is chosen by :meth:`_initial_steps`.
        """
        self.opt = opt
        self.iter_record_path = os.path.join(
            opt.checkpoints_dir, opt.name, 'iter.txt')
        # "unaligned" dataset modes count two images per sample.
        # NOTE(review): presumably one image per domain - confirm.
        if "unaligned" in opt.dataset_mode:
            self.batch_size = opt.batch_size * 2
        else:
            self.batch_size = opt.batch_size
        # name -> running average of seconds per image, filled in by
        # time_measurement() sections.
        self.time_measurements = {}
        self.steps_so_far = self._initial_steps()

    def _initial_steps(self):
        """Return the image count to start from, per the resume options.

        Resuming only applies when ``opt.isTrain`` and ``opt.continue_train``
        are both set:
          * ``resume_iter == "latest"`` with no ``pretrained_name``: read the
            recorded count from ``iter_record_path``.
          * ``resume_iter`` numeric once any ``k`` is removed (e.g. ``"50k"``):
            use that number, multiplied by 1000 if a ``k`` was present.
        Anything else starts from 0.
        """
        opt = self.opt
        resuming = opt.isTrain and opt.continue_train
        if (resuming and opt.resume_iter == "latest"
                and opt.pretrained_name is None):
            return self._load_recorded_steps()
        digits = opt.resume_iter.replace("k", "") if resuming else ""
        if resuming and digits.isnumeric():
            steps = int(digits)
            if "k" in opt.resume_iter:
                steps *= 1000
            return steps
        return 0

    def _load_recorded_steps(self):
        """Read the saved image count from ``iter_record_path``.

        Returns 0, with a message, if the file cannot be read or does not
        hold exactly one integer.
        """
        try:
            # A single value loads as a 0-d array. Convert it to a plain int
            # before storing anything, so a malformed record can never leave
            # a half-parsed array behind as the counter.
            steps = int(np.loadtxt(
                self.iter_record_path, delimiter=',', dtype=int))
        except Exception:  # deliberate best-effort: fall back to a fresh start
            print('Could not load iteration record at %s. '
                  'Starting from beginning.' % self.iter_record_path)
            return 0
        print('Resuming from iteration %d' % steps)
        return steps

    def record_one_iteration(self):
        """Advance the counter by one (effective) batch.

        If a save is due (see :meth:`needs_saving`), the count *before* the
        increment is first written to ``iter_record_path``.
        """
        if self.needs_saving():
            np.savetxt(self.iter_record_path,
                       [self.steps_so_far], delimiter=',', fmt='%d')
            print("Saved current iter count at %s" % self.iter_record_path)
        self.steps_so_far += self.batch_size

    def _is_due(self, freq):
        """True when a multiple of ``freq`` lies in
        ``(steps_so_far - batch_size, steps_so_far]``. That is, the counter
        has just reached or passed a multiple of ``freq`` images."""
        return (self.steps_so_far % freq) < self.batch_size

    def needs_saving(self):
        """Whether the save interval (``opt.save_freq``) was just reached.

        This is also true at a count of 0.
        """
        return self._is_due(self.opt.save_freq)

    def needs_evaluation(self):
        """Whether the evaluation interval was just reached.

        Unlike the other checks, this is never true before
        ``opt.evaluation_freq`` images have been processed.
        """
        return (self.steps_so_far >= self.opt.evaluation_freq
                and self._is_due(self.opt.evaluation_freq))

    def needs_printing(self):
        """Whether the print interval (``opt.print_freq``) was just reached."""
        return self._is_due(self.opt.print_freq)

    def needs_displaying(self):
        """Whether the display interval (``opt.display_freq``) was just
        reached."""
        return self._is_due(self.opt.display_freq)

    def completed_training(self):
        """Whether at least ``opt.total_nimgs`` images have been processed."""
        return self.steps_so_far >= self.opt.total_nimgs

    class TimeMeasurement:
        """Context manager that times a section of code.

        On exit, the elapsed time divided by ``parent.batch_size`` (seconds
        per image) is folded into ``parent.time_measurements[name]``. The
        first sample is stored as-is. Later samples update an exponential
        moving average that gives weight 0.02 to the new sample.
        """

        def __init__(self, name, parent):
            self.name = name
            self.parent = parent

        @staticmethod
        def _sync():
            # CUDA kernels run asynchronously, so wait for queued work before
            # reading the clock. torch.cuda.synchronize() raises when CUDA is
            # unavailable, so skip it on CPU-only setups.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        def __enter__(self):
            # Syncing here keeps earlier queued GPU work out of this section.
            self._sync()
            self.start_time = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self._sync()
            elapsed_time = ((time.perf_counter() - self.start_time)
                            / self.parent.batch_size)
            measurements = self.parent.time_measurements
            if self.name in measurements:
                measurements[self.name] = (measurements[self.name] * 0.98
                                           + elapsed_time * 0.02)
            else:
                measurements[self.name] = elapsed_time
            # Returning None lets any exception from the section propagate.

    def time_measurement(self, name):
        """Return a context manager that times a section under ``name``."""
        return IterationCounter.TimeMeasurement(name, self)