import os
import random
import string
import subprocess
import sys
import time

import numpy as np
import torch
import yaml

import utils.util as util
from utils.util import EasyDict as edict

# torch.backends.cudnn.enabled = False
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True


def parse_arguments(args):
    """Parse ``--key1.key2.key3=value`` command-line flags into a nested edict.

    Accepted forms:
      * ``--a.b=value``: ``value`` is parsed with ``yaml.safe_load`` (so ``3``
        becomes an int, ``[1,2]`` a list, ``null`` None, ...). Only the first
        ``=`` separates key from value, so values may themselves contain ``=``.
      * ``--a.b``: shorthand for ``--a.b=true``.
      * ``--a.b!``: shorthand for ``--a.b=false``.

    Args:
        args: iterable of argument strings (e.g. ``sys.argv[1:]``).

    Returns:
        edict: nested options, one level per dot-separated key component.

    Raises:
        ValueError: if an argument does not start with ``--``, if the same key
            is given twice, or if a key is used both as a leaf value and as the
            parent of other keys.
    """
    opt_cmd = {}
    for arg in args:
        if not arg.startswith("--"):
            raise ValueError("invalid argument {!r}: expected --key=value".format(arg))
        body = arg[2:]
        if "=" in body:
            # maxsplit=1: a value such as "a=b" must not break the unpacking
            key_str, value = body.split("=", 1)
        elif body.endswith("!"):
            key_str, value = body[:-1], "false"
        else:
            key_str, value = body, "true"
        keys_sub = key_str.split(".")
        # walk/create the intermediate levels for key1.key2...
        opt_sub = opt_cmd
        for k in keys_sub[:-1]:
            opt_sub = opt_sub.setdefault(k, {})
            if not isinstance(opt_sub, dict):
                raise ValueError("conflicting argument {!r}: {!r} already has a value".format(arg, k))
        if keys_sub[-1] in opt_sub:
            raise ValueError("duplicate argument for key {!r}".format(key_str))
        opt_sub[keys_sub[-1]] = yaml.safe_load(value)
    return edict(opt_cmd)


def set(opt_cmd=None, verbose=True, safe_check=True):
    """Build the complete option tree for a run.

    Loads the yaml file named by ``opt_cmd["yaml"]`` (including its
    ``_parent_`` chain), overrides it with the command-line options, then
    calls :func:`process_options` (which seeds RNGs, fills derived entries and
    creates ``opt.output_path``). Optionally prints the result.

    Args:
        opt_cmd: options as returned by :func:`parse_arguments`; must contain a
            ``yaml`` key (``--yaml=<path>`` on the command line).
        verbose: print the final option tree.
        safe_check: prompt before accepting a command-line key that is not
            present in the yaml options (see :func:`override_options`).

    Returns:
        The merged options.

    Raises:
        ValueError: if no ``yaml`` path was supplied.
    """
    print("setting configurations...")
    fname = opt_cmd.get("yaml") if opt_cmd is not None else None
    if fname is None:
        raise ValueError("no options file given (expected --yaml=<path>)")
    # load from yaml file
    opt_base = load_options(fname)
    # override with command line arguments
    opt = override_options(opt_base, opt_cmd, key_stack=[], safe_check=safe_check)
    process_options(opt)
    if verbose:
        _print_options(opt)
    return opt


def _print_options(opt, level=0):
    """Print a nested option tree as ``* key: value`` lines, indenting per level."""
    for key, value in sorted(opt.items()):
        if isinstance(value, (dict, edict)):
            print(" " * level + "* {}:".format(key))
            _print_options(value, level + 1)
        else:
            print(" " * level + "* {}:".format(key), value)


def load_options(fname):
    """Load options from a yaml file, resolving ``_parent_`` inheritance.

    If the file has a ``_parent_`` entry (one path or a list of paths), each
    parent is loaded recursively and used as a base that this file's options
    override. With several parents, earlier ones take precedence: each
    successive parent only contributes keys not already set. Parent paths are
    opened as written, i.e. relative to the current working directory rather
    than to ``fname``.

    Args:
        fname: path of the yaml file.

    Returns:
        edict: the merged options (``_parent_`` removed).
    """
    with open(fname) as file:
        # an empty yaml file loads as None; treat it as "no options"
        opt = edict(yaml.safe_load(file) or {})
    if "_parent_" in opt:
        parent_fnames = opt.pop("_parent_")
        if isinstance(parent_fnames, str):
            parent_fnames = [parent_fnames]
        for parent_fname in parent_fnames:
            # options gathered so far override the parent's
            opt = override_options(load_options(parent_fname), opt, key_stack=[])
    print("loading {}...".format(fname))
    return opt


def override_options(opt, opt_over, key_stack=None, safe_check=False):
    """Recursively merge ``opt_over`` into ``opt`` (in place) and return ``opt``.

    Nested dicts are merged key by key; every other value in ``opt_over``
    replaces the corresponding entry of ``opt``.

    Args:
        opt: base options; mutated in place.
        opt_over: overriding options.
        key_stack: key path of ``opt`` inside the root options, used only to
            build the prompt message. ``None`` means the root.
        safe_check: when True, a leaf key missing from ``opt`` triggers an
            interactive y/n prompt; answering "n" exits the program.

    Returns:
        ``opt`` after merging.
    """
    if key_stack is None:
        key_stack = []
    for key, value in opt_over.items():
        if isinstance(value, dict):
            # parse child options (until leaf nodes are reached)
            sub_opt = opt.get(key)
            if not isinstance(sub_opt, dict):
                # missing, or a scalar (e.g. an empty yaml entry): start a fresh level
                sub_opt = dict()
            opt[key] = override_options(sub_opt, value, key_stack=key_stack + [key], safe_check=safe_check)
        else:
            # ensure command line argument to override is also in yaml file
            if safe_check and key not in opt:
                key_str = ".".join(str(k) for k in key_stack + [key])
                add_new = None
                while add_new not in ("y", "n"):
                    add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str)).strip().lower()
                if add_new == "n":
                    print("safe exiting...")
                    sys.exit()
            opt[key] = value
    return opt


def process_options(opt):
    """Apply the global side effects of ``opt`` and fill in derived entries.

    * If ``opt.seed`` is not None, seeds ``random``, ``numpy`` and ``torch``
      (CPU and all CUDA devices) with it. Otherwise, appends a random 4-letter
      suffix to ``opt.name`` so that unseeded runs are unlikely to share an
      output directory.
    * Sets ``opt.output_path`` to ``<output_root>/<group>/<name>`` and creates it.
    * Unpacks ``opt.image_size`` into ``opt.H`` and ``opt.W``.
    * Defaults ``opt.freq.eval`` to ``max(max_epoch // 20, 1)`` when it is None.
    * If a ``loss_weight`` entry exists, sets ``get_depth`` and ``get_normal``
      to False.
    """
    # set seed
    if opt.seed is not None:
        random.seed(opt.seed)
        np.random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed_all(opt.seed)
    else:
        # create random string as run ID
        randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4))
        opt.name += "_{}".format(randkey)
    # other default options
    opt.output_path = "{0}/{1}/{2}".format(opt.output_root, opt.group, opt.name)
    os.makedirs(opt.output_path, exist_ok=True)
    opt.H, opt.W = opt.image_size
    if opt.freq.eval is None:
        opt.freq.eval = max(opt.max_epoch // 20, 1)
    # NOTE(review): presumably these flags are re-enabled elsewhere based on the
    # loss weights -- confirm against the callers
    if "loss_weight" in opt:
        opt.get_depth = False
        opt.get_normal = False


def save_options_file(opt):
    """Write ``opt`` to ``<output_path>/options.yaml``.

    If that file already exists and differs from ``opt``, the differences are
    shown with ``diff``. Unless ``opt.debug`` is set, the function then waits
    10 seconds so the user can abort before the old file is overwritten.
    """
    opt_fname = "{}/options.yaml".format(opt.output_path)
    if os.path.isfile(opt_fname):
        with open(opt_fname) as file:
            opt_old = yaml.safe_load(file)
        if opt != opt_old:
            # prompt if options are not identical
            opt_new_fname = "{}/options_temp.yaml".format(opt.output_path)
            with open(opt_new_fname, "w") as file:
                yaml.safe_dump(util.to_dict(opt), file, default_flow_style=False, indent=4)
            print("existing options file found (different from current one)...")
            # flush so our message is not printed after the child's diff output
            sys.stdout.flush()
            try:
                # argv list, no shell: paths with spaces/metacharacters are safe
                subprocess.run(["diff", opt_fname, opt_new_fname], check=False)
            except FileNotFoundError:
                print("(diff is not available; skipping comparison)")
            finally:
                os.remove(opt_new_fname)
            if not opt.debug:
                print("please cancel within 10 seconds if you do not want to override...")
                time.sleep(10)
        else:
            print("existing options file found (identical)")
    else:
        print("(creating new options file...)")
    with open(opt_fname, "w") as file:
        yaml.safe_dump(util.to_dict(opt), file, default_flow_style=False, indent=4)