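"""Base configuration: typed, validated default options that a YAML file can override."""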
import yaml
import copy
from typing import Union


class BaseConfig:
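    """Registry of typed, validated config options with YAML overrides.

    Subclasses register options via _add_option(); parse_config() overlays
    values from a YAML file onto the defaults and returns a plain dict.
    """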
    def __init__(self):
        self.__config_dict = {}
        self.__check_func_dict = {}
        self.__type_dict = {}  # declared type of each option, used by parse_config
        is_greater_than_0 = lambda x: x > 0
        # common config
        self._add_option('common', 'name', str, 'style_master')
        self._add_option('common', 'model', str, 'cycle_gan')
        self._add_option('common', 'phase', str, 'train', check_func=lambda x: x in ['train', 'test'])
        self._add_option('common', 'gpu_ids', list, [0])
        self._add_option('common', 'verbose', bool, False)
        # model config
        self._add_option('model', 'input_nc', int, 3, check_func=is_greater_than_0)
        self._add_option('model', 'output_nc', int, 3, check_func=is_greater_than_0)
        # dataset config
        # common dataset options
        self._add_option('dataset', 'use_absolute_datafile', bool, True)
        self._add_option('dataset', 'batch_size', int, 1, check_func=is_greater_than_0)
        self._add_option('dataset', 'n_threads', int, 4, check_func=is_greater_than_0)
        self._add_option('dataset', 'dataroot', str, './')
        self._add_option('dataset', 'drop_last', bool, False)
        # landmark_scale is optional and may stay None, hence the Union type
        self._add_option('dataset', 'landmark_scale', Union[list, type(None)], None)
        self._add_option('dataset', 'check_all_data', bool, False)
        # If accept_data_error is True, the dataloader catches the error raised by a
        # bad sample and loads the next good one; if False, the error propagates and
        # the process crashes.
        self._add_option('dataset', 'accept_data_error', bool, True)
        self._add_option('dataset', 'train_data', dict, {})
        self._add_option('dataset', 'val_data', dict, {})
        # paired data config
        self._add_option('dataset', 'paired_trainA_folder', str, '')
        self._add_option('dataset', 'paired_trainB_folder', str, '')
        self._add_option('dataset', 'paired_train_filelist', str, '')
        self._add_option('dataset', 'paired_valA_folder', str, '')
        self._add_option('dataset', 'paired_valB_folder', str, '')
        self._add_option('dataset', 'paired_val_filelist', str, '')
        # unpaired data config
        self._add_option('dataset', 'unpaired_trainA_folder', str, '')
        self._add_option('dataset', 'unpaired_trainB_folder', str, '')
        self._add_option('dataset', 'unpaired_trainA_filelist', str, '')
        self._add_option('dataset', 'unpaired_trainB_filelist', str, '')
        self._add_option('dataset', 'unpaired_valA_folder', str, '')
        self._add_option('dataset', 'unpaired_valB_folder', str, '')
        self._add_option('dataset', 'unpaired_valA_filelist', str, '')
        self._add_option('dataset', 'unpaired_valB_filelist', str, '')
        # custom data
        self._add_option('dataset', 'custom_train_data', dict, {})
        self._add_option('dataset', 'custom_val_data', dict, {})
        # training config
        self._add_option('training', 'checkpoints_dir', str, './checkpoints')
        self._add_option('training', 'log_dir', str, './logs')
        self._add_option('training', 'use_new_log', bool, False)
        self._add_option('training', 'continue_train', bool, False)
        self._add_option('training', 'which_epoch', str, 'latest')
        self._add_option('training', 'n_epochs', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'n_epochs_decay', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'save_latest_freq', int, 5000, check_func=is_greater_than_0)
        self._add_option('training', 'print_freq', int, 200, check_func=is_greater_than_0)
        self._add_option('training', 'save_epoch_freq', int, 5, check_func=is_greater_than_0)
        self._add_option('training', 'epoch_as_iter', bool, False)
        self._add_option('training', 'lr', float, 2e-4, check_func=is_greater_than_0)
        self._add_option('training', 'lr_policy', str, 'linear',
                         check_func=lambda x: x in ['linear', 'step', 'plateau', 'cosine'])
        self._add_option('training', 'lr_decay_iters', int, 50, check_func=is_greater_than_0)
        self._add_option('training', 'DDP', bool, False)
        self._add_option('training', 'num_nodes', int, 1, check_func=is_greater_than_0)
        self._add_option('training', 'DDP_address', str, '127.0.0.1')
        self._add_option('training', 'DDP_port', str, '29700')
        # a DDP option that allows backward on a subgraph of the model
        self._add_option('training', 'find_unused_parameters', bool, False)
        # use this percentage of the training data for validation
        self._add_option('training', 'val_percent', float, 5.0, check_func=is_greater_than_0)
        self._add_option('training', 'val', bool, True)  # perform validation every epoch
        # save images to create a training progression video
        self._add_option('training', 'save_training_progress', bool, False)
        # testing config
        self._add_option('testing', 'results_dir', str, './results')
        self._add_option('testing', 'load_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'crop_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'preprocess', list, ['scale_width'])
        self._add_option('testing', 'visual_names', list, [])
        self._add_option('testing', 'num_test', int, 999999, check_func=is_greater_than_0)
        self._add_option('testing', 'image_format', str, 'jpg',
                         check_func=lambda x: x in ['input', 'jpg', 'jpeg', 'png'])

    def _add_option(self, group_name, option_name, value_type, default_value, check_func=None):
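        """Register option <group_name>/<option_name> with a typed default value.

        value_type may be a plain type or a typing.Union; check_func, if given,
        must return True for a value to be accepted.
        """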
        # group and option names must be strings
        if not isinstance(group_name, str) or not isinstance(option_name, str):
            raise TypeError('Type of {} and {} must be str.'.format(group_name, option_name))
        # add group
        if group_name not in self.__config_dict:
            self.__config_dict[group_name] = {}
            self.__check_func_dict[group_name] = {}
            self.__type_dict[group_name] = {}
        # check that value_type is a plain type or a typing.Union, and that the
        # default value matches it
        is_union = getattr(value_type, '__origin__', None) is Union
        if not is_union and not isinstance(value_type, type):
            raise TypeError('{} is not a type.'.format(value_type))
        if is_union:
            if not isinstance(default_value, value_type.__args__):
                raise TypeError('Type of {} must be {}.'.format(default_value, value_type))
        elif type(default_value) is not value_type:
            raise TypeError('Type of {} must be {}.'.format(default_value, value_type))
        # add option to dict
        if option_name not in self.__config_dict[group_name]:
            if check_func is not None and not check_func(default_value):
                raise ValueError('Checking {}/{} failed.'.format(group_name, option_name))
            self.__config_dict[group_name][option_name] = default_value
            self.__check_func_dict[group_name][option_name] = check_func
            self.__type_dict[group_name][option_name] = value_type
        else:
            raise ValueError('{} has already been added.'.format(option_name))

    def parse_config(self, cfg_file):
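        """Load <cfg_file> as YAML and overlay it onto the registered defaults.

        Returns a deep copy of the config dict with the validated overrides
        applied; groups or options not registered via _add_option() are ignored.
        """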
        # load config from yaml file
        with open(cfg_file, 'r') as f:
            yaml_config = yaml.safe_load(f)
        if not isinstance(yaml_config, dict):
            raise ValueError('Loading yaml file failed.')
        # overlay the yaml values onto a copy of the defaults
        config_dict = copy.deepcopy(self.__config_dict)
        for group in config_dict:
            if group in yaml_config:
                for option in config_dict[group]:
                    if option in yaml_config[group]:
                        value = yaml_config[group][option]
                        value_type = self.__type_dict[group][option]
                        if getattr(value_type, '__origin__', None) is Union:
                            # for a Union, the value must match one of its member types
                            if not isinstance(value, value_type.__args__):
                                raise TypeError('Type of {}/{} must be {}.'.format(
                                    group, option, value_type.__args__))
                        elif type(value) is not value_type:
                            raise TypeError('Type of {}/{} must be {}.'.format(group, option, value_type))
                        check_func = self.__check_func_dict[group][option]
                        if check_func is not None and not check_func(value):
                            raise ValueError('Checking {}/{} failed.'.format(group, option))
                        config_dict[group][option] = value
        return config_dict
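

# A minimal usage sketch (illustration only, not part of the original file): the
# YAML contents and option values below are assumptions chosen to exercise a few
# of the registered options; any subset of groups/options may appear in a real file.
if __name__ == '__main__':
    import os
    import tempfile

    sample_yaml = (
        'common:\n'
        '  name: demo_run\n'
        '  verbose: true\n'
        'training:\n'
        '  lr: 0.0002\n'
    )
    # write the sample config to a temporary file, then parse it
    with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as f:
        f.write(sample_yaml)
        cfg_path = f.name
    try:
        cfg = BaseConfig().parse_config(cfg_path)
        print(cfg['common']['name'])  # demo_run
        print(cfg['training']['lr'])  # 0.0002
    finally:
        os.remove(cfg_path)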