Spaces:
Runtime error
Runtime error
import os | |
import numpy as np | |
import torch | |
import time | |
class IterationCounter:
    """Track training progress and decide when periodic actions are due.

    ``steps_so_far`` is advanced by ``batch_size`` on every call to
    ``record_one_iteration`` and is compared against ``opt.total_nimgs`` to
    detect the end of training. The counter can be restored from
    ``<checkpoints_dir>/<name>/iter.txt`` or from an explicit
    ``opt.resume_iter`` value such as ``"50000"`` or ``"50k"``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the counter's frequency/length options on ``parser``.

        Args:
            parser: object exposing an argparse-style ``add_argument``.
            is_train: unused here; kept for a uniform option-hook signature.

        Returns:
            The same ``parser`` that was passed in.
        """
        parser.add_argument("--total_nimgs", default=25 * (1000 ** 2),
                            type=int)
        parser.add_argument("--save_freq", default=50000, type=int)
        parser.add_argument("--evaluation_freq", default=50000, type=int)
        parser.add_argument("--print_freq", default=480, type=int)
        parser.add_argument("--display_freq", default=1600, type=int)
        return parser

    def __init__(self, opt):
        """Initialise the counter, optionally resuming a previous run.

        Args:
            opt: options object. Reads ``checkpoints_dir``, ``name``,
                ``dataset_mode``, ``batch_size``, ``isTrain``,
                ``continue_train``, ``resume_iter`` and ``pretrained_name``.
        """
        self.opt = opt
        self.iter_record_path = os.path.join(
            self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
        self.steps_so_far = 0

        # "unaligned" dataset modes count twice the batch size per iteration.
        if "unaligned" in opt.dataset_mode:
            self.batch_size = opt.batch_size * 2
        else:
            self.batch_size = opt.batch_size

        # Maps a measurement name to its running per-sample time in seconds.
        self.time_measurements = {}

        resuming = opt.isTrain and opt.continue_train
        automatically_find_resume_iter = (
            resuming
            and opt.resume_iter == "latest"
            and opt.pretrained_name is None
        )
        resume_at_specified_iter = (
            resuming and opt.resume_iter.replace("k", "").isnumeric()
        )

        if automatically_find_resume_iter:
            try:
                loaded = np.loadtxt(
                    self.iter_record_path, delimiter=',', dtype=int)
                # Convert to a plain int *before* touching the attribute so a
                # malformed record (e.g. several values) cannot leave an
                # ndarray behind in ``steps_so_far``.
                steps = int(loaded.item())
                self.steps_so_far = steps
                print('Resuming from iteration %d' % (self.steps_so_far))
            except Exception:
                # Best effort: any failure to read the record means a fresh
                # start, with ``steps_so_far`` still at 0.
                print('Could not load iteration record at %s. '
                      'Starting from beginning.' % self.iter_record_path)
        elif resume_at_specified_iter:
            steps = int(opt.resume_iter.replace("k", ""))
            if "k" in opt.resume_iter:
                steps *= 1000
            self.steps_so_far = steps
        else:
            self.steps_so_far = 0

    def record_one_iteration(self):
        """Persist the count when a save is due, then advance one batch."""
        if self.needs_saving():
            np.savetxt(self.iter_record_path,
                       [self.steps_so_far], delimiter=',', fmt='%d')
            print("Saved current iter count at %s" % self.iter_record_path)
        self.steps_so_far += self.batch_size

    def _is_due(self, freq):
        """Return True when the count is within one batch past a multiple of ``freq``."""
        return (self.steps_so_far % freq) < self.batch_size

    def needs_saving(self):
        """Return True when a checkpoint save is due (per ``opt.save_freq``)."""
        return self._is_due(self.opt.save_freq)

    def needs_evaluation(self):
        """Return True when evaluation is due; never before the first interval."""
        return (self.steps_so_far >= self.opt.evaluation_freq) and \
            self._is_due(self.opt.evaluation_freq)

    def needs_printing(self):
        """Return True when a log print is due (per ``opt.print_freq``)."""
        return self._is_due(self.opt.print_freq)

    def needs_displaying(self):
        """Return True when a display update is due (per ``opt.display_freq``)."""
        return self._is_due(self.opt.display_freq)

    def completed_training(self):
        """Return True once ``steps_so_far`` has reached ``opt.total_nimgs``."""
        return (self.steps_so_far >= self.opt.total_nimgs)

    class TimeMeasurement:
        """Context manager that records a smoothed per-sample duration.

        On exit the elapsed wall-clock time is divided by the parent's
        ``batch_size`` and blended into ``parent.time_measurements[name]``
        as an exponential moving average (0.98 old / 0.02 new).
        """

        def __init__(self, name, parent):
            self.name = name
            self.parent = parent

        def __enter__(self):
            self.start_time = time.time()
            return self

        def __exit__(self, exc_type, value, traceback):
            # Wait for queued GPU work so the wall-clock reading covers it.
            # Skipped when CUDA is unavailable, where synchronize() raises.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed_time = (
                (time.time() - self.start_time) / self.parent.batch_size
            )
            measurements = self.parent.time_measurements
            if self.name not in measurements:
                measurements[self.name] = elapsed_time
            else:
                prev_time = measurements[self.name]
                measurements[self.name] = (
                    prev_time * 0.98 + elapsed_time * 0.02
                )

    def time_measurement(self, name):
        """Return a context manager that times a block under ``name``."""
        return IterationCounter.TimeMeasurement(name, self)