import copy
from typing import Union

import yaml


class BaseConfig:
    """Registry of typed configuration options with defaults, overridable from YAML.

    Options are organised as ``group -> option -> value``. Each option is
    registered through :meth:`_add_option` with a declared type, a default
    value and an optional validation callable. :meth:`parse_config` returns a
    copy of the defaults in which the values found in a YAML file have been
    substituted.

    Error policy (kept compatible with the historical behaviour of this class):
    type mismatches are reported on stdout and the default is kept, whereas a
    failing ``check_func`` or a duplicate registration raises ``Exception``.
    """

    def __init__(self):
        """Register the built-in option groups and their default values."""
        self.__config_dict = {}       # group -> option -> default value
        self.__check_func_dict = {}   # group -> option -> validator or None
        self.__type_dict = {}         # group -> option -> declared type

        def is_greater_than_0(x):
            return x > 0

        # common config
        self._add_option('common', 'name', str, 'style_master')
        self._add_option('common', 'model', str, 'cycle_gan')
        self._add_option('common', 'phase', str, 'train',
                         check_func=lambda x: x in ['train', 'test'])
        self._add_option('common', 'gpu_ids', list, [0])
        self._add_option('common', 'verbose', bool, False)

        # model config
        self._add_option('model', 'input_nc', int, 3, check_func=is_greater_than_0)
        self._add_option('model', 'output_nc', int, 3, check_func=is_greater_than_0)

        # dataset config
        # common dataset options
        self._add_option('dataset', 'use_absolute_datafile', bool, True)
        self._add_option('dataset', 'batch_size', int, 1, check_func=is_greater_than_0)
        self._add_option('dataset', 'n_threads', int, 4, check_func=is_greater_than_0)
        self._add_option('dataset', 'dataroot', str, './')
        self._add_option('dataset', 'drop_last', bool, False)
        # Declared as "list or None": the default is None, and declaring plain
        # ``list`` made it impossible to ever set this option from a YAML file.
        self._add_option('dataset', 'landmark_scale', Union[list, None], None)
        self._add_option('dataset', 'check_all_data', bool, False)
        # Upon loading bad data, if this is true, the dataloader will throw an
        # exception and load the next good data. If this is false, the process
        # will crash.
        self._add_option('dataset', 'accept_data_error', bool, True)
        self._add_option('dataset', 'train_data', dict, {})
        self._add_option('dataset', 'val_data', dict, {})

        # paired data config
        self._add_option('dataset', 'paired_trainA_folder', str, '')
        self._add_option('dataset', 'paired_trainB_folder', str, '')
        self._add_option('dataset', 'paired_train_filelist', str, '')
        self._add_option('dataset', 'paired_valA_folder', str, '')
        self._add_option('dataset', 'paired_valB_folder', str, '')
        self._add_option('dataset', 'paired_val_filelist', str, '')

        # unpaired data config
        self._add_option('dataset', 'unpaired_trainA_folder', str, '')
        self._add_option('dataset', 'unpaired_trainB_folder', str, '')
        self._add_option('dataset', 'unpaired_trainA_filelist', str, '')
        self._add_option('dataset', 'unpaired_trainB_filelist', str, '')
        self._add_option('dataset', 'unpaired_valA_folder', str, '')
        self._add_option('dataset', 'unpaired_valB_folder', str, '')
        self._add_option('dataset', 'unpaired_valA_filelist', str, '')
        self._add_option('dataset', 'unpaired_valB_filelist', str, '')

        # custom data
        self._add_option('dataset', 'custom_train_data', dict, {})
        self._add_option('dataset', 'custom_val_data', dict, {})

        # training config
        self._add_option('training', 'checkpoints_dir', str, './checkpoints')
        self._add_option('training', 'log_dir', str, './logs')
        self._add_option('training', 'use_new_log', bool, False)
        self._add_option('training', 'continue_train', bool, False)
        self._add_option('training', 'which_epoch', str, 'latest')
        self._add_option('training', 'n_epochs', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'n_epochs_decay', int, 100,
                         check_func=is_greater_than_0)
        self._add_option('training', 'save_latest_freq', int, 5000,
                         check_func=is_greater_than_0)
        self._add_option('training', 'print_freq', int, 200, check_func=is_greater_than_0)
        self._add_option('training', 'save_epoch_freq', int, 5,
                         check_func=is_greater_than_0)
        self._add_option('training', 'epoch_as_iter', bool, False)
        self._add_option('training', 'lr', float, 2e-4, check_func=is_greater_than_0)
        self._add_option('training', 'lr_policy', str, 'linear',
                         check_func=lambda x: x in ['linear', 'step', 'plateau', 'cosine'])
        self._add_option('training', 'lr_decay_iters', int, 50,
                         check_func=is_greater_than_0)
        self._add_option('training', 'DDP', bool, False)
        self._add_option('training', 'num_nodes', int, 1, check_func=is_greater_than_0)
        self._add_option('training', 'DDP_address', str, '127.0.0.1')
        self._add_option('training', 'DDP_port', str, '29700')
        # a DDP option that allows backward on a subgraph of the model
        self._add_option('training', 'find_unused_parameters', bool, False)
        # Uses x% of training data to validate
        self._add_option('training', 'val_percent', float, 5.0,
                         check_func=is_greater_than_0)
        # perform validation every epoch
        self._add_option('training', 'val', bool, True)
        # save images to create a training progression video
        self._add_option('training', 'save_training_progress', bool, False)

        # testing config
        self._add_option('testing', 'results_dir', str, './results')
        self._add_option('testing', 'load_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'crop_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'preprocess', list, ['scale_width'])
        self._add_option('testing', 'visual_names', list, [])
        self._add_option('testing', 'num_test', int, 999999, check_func=is_greater_than_0)
        self._add_option('testing', 'image_format', str, 'jpg',
                         check_func=lambda x: x in ['input', 'jpg', 'jpeg', 'png'])

    @staticmethod
    def _is_union(value_type):
        """Return True when *value_type* is a ``typing.Union[...]`` annotation."""
        return getattr(value_type, '__origin__', None) is Union

    @staticmethod
    def _matches_type(value, value_type):
        """Return True when *value* is acceptable for the declared *value_type*.

        Plain classes require an exact type match (so ``True`` is not accepted
        for an ``int`` option). For a ``Union`` the value must be an instance
        of one of its members; subscripted members such as ``List[int]`` are
        reduced to their runtime class (``list``) first.
        """
        if BaseConfig._is_union(value_type):
            members = tuple(getattr(arg, '__origin__', arg) for arg in value_type.__args__)
            return isinstance(value, members)
        return type(value) is value_type

    def _add_option(self, group_name, option_name, value_type, default_value,
                    check_func=None):
        """Register one option together with its type, default and validator.

        Args:
            group_name: Name of the group the option belongs to (``str``).
            option_name: Name of the option inside the group (``str``).
            value_type: Declared type; a class or a ``typing.Union`` of classes.
            default_value: Value used when the YAML file does not override it.
            check_func: Optional callable taking the value and returning a
                truthy result when the value is valid.

        Raises:
            Exception: If a name is not a ``str``, if the default fails
                ``check_func``, or if the option was already registered.

        A ``value_type`` that is not a type, or a default that does not match
        it, is only reported on stdout; registration still proceeds.
        """
        # check name type
        if type(group_name) is not str or type(option_name) is not str:
            raise Exception('Type of {} and {} must be str.'.format(group_name, option_name))

        # add group
        if group_name not in self.__config_dict:
            self.__config_dict[group_name] = {}
            self.__check_func_dict[group_name] = {}
            self.__type_dict[group_name] = {}

        # check type & default value (diagnostic only, never fatal)
        if type(value_type) is not type and not self._is_union(value_type):
            print('{} is not a type.'.format(value_type))
        elif not self._matches_type(default_value, value_type):
            print('Type of {} must be {}.'.format(default_value, value_type))

        # add option to dict
        if option_name in self.__config_dict[group_name]:
            raise Exception('{} has been already added.'.format(option_name))
        if check_func is not None and not check_func(default_value):
            raise Exception('Checking {}/{} failed.'.format(group_name, option_name))

        self.__config_dict[group_name][option_name] = default_value
        self.__check_func_dict[group_name][option_name] = check_func
        self.__type_dict[group_name][option_name] = value_type

    def parse_config(self, cfg_file):
        """Load *cfg_file* (YAML) and return the defaults overridden by its values.

        Only groups and options that were registered are read; unknown keys in
        the file are ignored. The registered defaults themselves are not
        modified, a deep copy is returned.

        Args:
            cfg_file: Path of the YAML file to read.

        Returns:
            dict: ``group -> option -> value``.

        Raises:
            Exception: If the file's top level is not a mapping, or if a value
                fails the option's ``check_func``.

        A value of the wrong type, or a group section that is not a mapping,
        is reported on stdout and the default is kept.
        """
        # load config from yaml file
        with open(cfg_file, 'r') as f:
            yaml_config = yaml.safe_load(f)
        if not isinstance(yaml_config, dict):
            raise Exception('Loading yaml file failed.')

        # replace default options
        config_dict = copy.deepcopy(self.__config_dict)
        for group, options in config_dict.items():
            overrides = yaml_config.get(group)
            if overrides is None:
                # group missing, or present but empty (``group:``) in the file
                continue
            if not isinstance(overrides, dict):
                print('Group {} must be a mapping of options.'.format(group))
                continue

            for option, default_value in options.items():
                if option not in overrides:
                    continue
                value = overrides[option]
                value_type = self.__type_dict[group][option]

                # Validate against the declared type. A value with exactly the
                # default's type is also accepted, so that everything that used
                # to load keeps loading.
                if not (self._matches_type(value, value_type)
                        or type(value) is type(default_value)):
                    print('Type of {}/{} must be {}.'.format(group, option, value_type))
                    continue

                check_func = self.__check_func_dict[group][option]
                if check_func is not None and not check_func(value):
                    raise Exception('Checking {}/{} failed.'.format(group, option))
                options[option] = value

        return config_dict