""" This file runs the main training/val loop """
import os
import json
import math
import sys
import pprint
import torch
from argparse import Namespace

# Make the project packages importable when the script is launched from the
# repository root or from its own sub-directory. This must run before the
# project imports below.
sys.path.append(".")
sys.path.append("..")

from options.train_options import TrainOptions  # noqa: E402
from training.coach import Coach  # noqa: E402


def main():
    """Parse training options, prepare the experiment directory and train.

    If ``opts.resume_training_from_ckpt`` is set, options are restored from
    that checkpoint via ``load_train_checkpoint``. Otherwise the
    progressive-training schedule is computed and a fresh experiment
    directory is created. In both cases a ``Coach`` is built and trained.
    """
    opts = TrainOptions().parse()
    previous_train_ckpt = None
    if opts.resume_training_from_ckpt:
        opts, previous_train_ckpt = load_train_checkpoint(opts)
    else:
        setup_progressive_steps(opts)
        create_initial_experiment_dir(opts)

    coach = Coach(opts, previous_train_ckpt)
    coach.train()


def load_train_checkpoint(opts):
    """Restore options from a training checkpoint and merge in new CLI options.

    Args:
        opts: Freshly parsed options. ``opts.resume_training_from_ckpt`` is
            the path of the checkpoint to resume from.

    Returns:
        Tuple ``(opts, previous_train_ckpt)``. The first item is a
        ``Namespace`` built from the checkpoint's ``'opts'`` dict, merged
        with the new options (see ``update_new_configs``). The second item is
        the full object returned by ``torch.load``.
    """
    train_ckpt_path = opts.resume_training_from_ckpt
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code -- only resume from checkpoints you trust.
    previous_train_ckpt = torch.load(train_ckpt_path, map_location='cpu')
    new_opts_dict = vars(opts)
    ckpt_opts = previous_train_ckpt['opts']
    # Record where this run was actually resumed from.
    ckpt_opts['resume_training_from_ckpt'] = train_ckpt_path
    update_new_configs(ckpt_opts, new_opts_dict)
    pprint.pprint(ckpt_opts)
    opts = Namespace(**ckpt_opts)
    if opts.sub_exp_dir is not None:
        # Nest the resumed run under a new sub-directory of the checkpoint's
        # experiment dir; create_initial_experiment_dir requires it not exist.
        opts.exp_dir = os.path.join(opts.exp_dir, opts.sub_exp_dir)
        create_initial_experiment_dir(opts)
    return opts, previous_train_ckpt


def setup_progressive_steps(opts):
    """Build (if requested) and validate ``opts.progressive_steps`` in place.

    The number of style layers is ``2 * log2(opts.stylegan_size) - 2``; for
    example, 18 when ``stylegan_size`` is 1024. When
    ``opts.progressive_start`` is not None, the schedule becomes
    ``[0, start, start + every, start + 2 * every, ...]``, with one entry per
    style layer (``every`` is ``opts.progressive_step_every``). Otherwise any
    ``opts.progressive_steps`` already present is validated as given.

    Raises:
        ValueError: if ``opts.progressive_steps`` is not None and is not a
            valid schedule (see ``is_valid_progressive_steps``).
    """
    # math.log2 is exact for powers of two, whereas math.log(x, 2) divides
    # two rounded logarithms and int() could then truncate to exponent - 1.
    log_size = int(math.log2(opts.stylegan_size))
    num_style_layers = 2 * log_size - 2
    num_deltas = num_style_layers - 1
    if opts.progressive_start is not None:  # progressive delta training
        opts.progressive_steps = [0]
        next_progressive_step = opts.progressive_start
        for _ in range(num_deltas):
            opts.progressive_steps.append(next_progressive_step)
            next_progressive_step += opts.progressive_step_every

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if opts.progressive_steps is not None and not is_valid_progressive_steps(opts, num_style_layers):
        raise ValueError("Invalid progressive training input")


def is_valid_progressive_steps(opts, num_style_layers):
    """Return True if ``opts.progressive_steps`` has one entry per style layer
    and its first entry is 0."""
    return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0


def create_initial_experiment_dir(opts):
    """Create ``opts.exp_dir`` and dump all options to ``opt.json`` inside it.

    The options are also pretty-printed to stdout.

    Raises:
        FileExistsError: if ``opts.exp_dir`` already exists. This is a
            subclass of ``Exception``, so existing broad handlers still catch it.
    """
    try:
        # Let makedirs perform the existence check itself, so there is no
        # window between "check" and "create" for another process to race.
        os.makedirs(opts.exp_dir)
    except FileExistsError:
        raise FileExistsError('Oops... {} already exists'.format(opts.exp_dir)) from None

    opts_dict = vars(opts)
    pprint.pprint(opts_dict)
    with open(os.path.join(opts.exp_dir, 'opt.json'), 'w', encoding='utf-8') as f:
        json.dump(opts_dict, f, indent=4, sort_keys=True)


def update_new_configs(ckpt_opts, new_opts):
    """Merge freshly parsed options into checkpoint options, mutating ``ckpt_opts``.

    Keys missing from ``ckpt_opts`` are copied from ``new_opts``. Keys listed
    in ``new_opts['update_param_list']`` (if present and non-empty) are
    overwritten with their values from ``new_opts``.
    """
    for key, value in new_opts.items():
        ckpt_opts.setdefault(key, value)
    for param in new_opts.get('update_param_list') or []:
        ckpt_opts[param] = new_opts[param]


if __name__ == '__main__':
    main()