# MMFS / configs/base_config.py
# Last commit 7e2a2a5 by limoran ("add basic files"), 9.18 kB
import yaml
import copy
from typing import Union
class BaseConfig:
    """Typed, validated registry of configuration options grouped by section.

    Defaults are registered in ``__init__`` through ``_add_option`` (further
    options can be registered the same way) and ``parse_config`` overrides
    them with values read from a YAML file.
    """

    def __init__(self):
        """Register the default options shared by every configuration."""
        self.__config_dict = {}      # group -> option -> default value
        self.__check_func_dict = {}  # group -> option -> validator (or None)
        self.__type_dict = {}        # group -> option -> declared value type

        def is_greater_than_0(x):
            return x > 0

        # common config
        self._add_option('common', 'name', str, 'style_master')
        self._add_option('common', 'model', str, 'cycle_gan')
        self._add_option('common', 'phase', str, 'train', check_func=lambda x: x in ['train', 'test'])
        self._add_option('common', 'gpu_ids', list, [0])
        self._add_option('common', 'verbose', bool, False)

        # model config
        self._add_option('model', 'input_nc', int, 3, check_func=is_greater_than_0)
        self._add_option('model', 'output_nc', int, 3, check_func=is_greater_than_0)

        # dataset config
        # common dataset options
        self._add_option('dataset', 'use_absolute_datafile', bool, True)
        self._add_option('dataset', 'batch_size', int, 1, check_func=is_greater_than_0)
        self._add_option('dataset', 'n_threads', int, 4, check_func=is_greater_than_0)
        self._add_option('dataset', 'dataroot', str, './')
        self._add_option('dataset', 'drop_last', bool, False)
        # Optional: a list, or None (the default) when unset.
        self._add_option('dataset', 'landmark_scale', Union[list, None], None)
        self._add_option('dataset', 'check_all_data', bool, False)
        # Upon loading bad data: if True, the dataloader throws an exception
        # and loads the next good data instead; if False, the process crashes.
        self._add_option('dataset', 'accept_data_error', bool, True)
        self._add_option('dataset', 'train_data', dict, {})
        self._add_option('dataset', 'val_data', dict, {})
        # paired data config
        self._add_option('dataset', 'paired_trainA_folder', str, '')
        self._add_option('dataset', 'paired_trainB_folder', str, '')
        self._add_option('dataset', 'paired_train_filelist', str, '')
        self._add_option('dataset', 'paired_valA_folder', str, '')
        self._add_option('dataset', 'paired_valB_folder', str, '')
        self._add_option('dataset', 'paired_val_filelist', str, '')
        # unpaired data config
        self._add_option('dataset', 'unpaired_trainA_folder', str, '')
        self._add_option('dataset', 'unpaired_trainB_folder', str, '')
        self._add_option('dataset', 'unpaired_trainA_filelist', str, '')
        self._add_option('dataset', 'unpaired_trainB_filelist', str, '')
        self._add_option('dataset', 'unpaired_valA_folder', str, '')
        self._add_option('dataset', 'unpaired_valB_folder', str, '')
        self._add_option('dataset', 'unpaired_valA_filelist', str, '')
        self._add_option('dataset', 'unpaired_valB_filelist', str, '')
        # custom data
        self._add_option('dataset', 'custom_train_data', dict, {})
        self._add_option('dataset', 'custom_val_data', dict, {})

        # training config
        self._add_option('training', 'checkpoints_dir', str, './checkpoints')
        self._add_option('training', 'log_dir', str, './logs')
        self._add_option('training', 'use_new_log', bool, False)
        self._add_option('training', 'continue_train', bool, False)
        self._add_option('training', 'which_epoch', str, 'latest')
        self._add_option('training', 'n_epochs', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'n_epochs_decay', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'save_latest_freq', int, 5000, check_func=is_greater_than_0)
        self._add_option('training', 'print_freq', int, 200, check_func=is_greater_than_0)
        self._add_option('training', 'save_epoch_freq', int, 5, check_func=is_greater_than_0)
        self._add_option('training', 'epoch_as_iter', bool, False)
        self._add_option('training', 'lr', float, 2e-4, check_func=is_greater_than_0)
        self._add_option('training', 'lr_policy', str, 'linear',
                         check_func=lambda x: x in ['linear', 'step', 'plateau', 'cosine'])
        self._add_option('training', 'lr_decay_iters', int, 50, check_func=is_greater_than_0)
        self._add_option('training', 'DDP', bool, False)
        self._add_option('training', 'num_nodes', int, 1, check_func=is_greater_than_0)
        self._add_option('training', 'DDP_address', str, '127.0.0.1')
        self._add_option('training', 'DDP_port', str, '29700')
        # A DDP option that allows backward on a subgraph of the model.
        self._add_option('training', 'find_unused_parameters', bool, False)
        # Uses x% of training data to validate.
        self._add_option('training', 'val_percent', float, 5.0, check_func=is_greater_than_0)
        # Perform validation every epoch.
        self._add_option('training', 'val', bool, True)
        # Save images to create a training progression video.
        self._add_option('training', 'save_training_progress', bool, False)

        # testing config
        self._add_option('testing', 'results_dir', str, './results')
        self._add_option('testing', 'load_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'crop_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'preprocess', list, ['scale_width'])
        self._add_option('testing', 'visual_names', list, [])
        self._add_option('testing', 'num_test', int, 999999, check_func=is_greater_than_0)
        self._add_option('testing', 'image_format', str, 'jpg',
                         check_func=lambda x: x in ['input', 'jpg', 'jpeg', 'png'])

    @staticmethod
    def _is_union(value_type):
        """Return True if ``value_type`` is a ``typing.Union`` (incl. ``Optional``)."""
        return getattr(value_type, '__origin__', None) is Union

    @classmethod
    def _type_matches(cls, value, value_type):
        """Return True if ``value`` is acceptable for the declared ``value_type``.

        Plain classes require an exact type match (so ``True`` is not an
        ``int``), except that an ``int`` is accepted where a ``float`` is
        declared (YAML writes ``5`` for 5.0). A ``Union`` matches if any of
        its members matches. Other annotations cannot be checked and are
        accepted.
        """
        if cls._is_union(value_type):
            return any(cls._type_matches(value, member) for member in value_type.__args__)
        if not isinstance(value_type, type):
            return True
        if value_type is float and type(value) is int:
            return True
        return type(value) is value_type

    def _add_option(self, group_name, option_name, value_type, default_value, check_func=None):
        """Register an option with its declared type, default and validator.

        Args:
            group_name: Section name (e.g. 'training'); created on first use.
            option_name: Option name, unique within its group.
            value_type: A class, or a ``typing.Union``/``Optional`` of classes.
            default_value: Default value; expected to match ``value_type``.
            check_func: Optional predicate that the default, and any value
                loaded by ``parse_config`` of the declared type, must satisfy.

        Raises:
            Exception: if a name is not a str, the option was already added,
                or ``check_func`` rejects ``default_value``.
        """
        if type(group_name) is not str or type(option_name) is not str:
            raise Exception('Type of {} and {} must be str.'.format(group_name, option_name))

        if group_name not in self.__config_dict:
            self.__config_dict[group_name] = {}
            self.__check_func_dict[group_name] = {}
            self.__type_dict[group_name] = {}

        # Type problems are reported, not raised, so that existing option
        # declarations whose default deviates from the declared type still load.
        if not isinstance(value_type, type) and not self._is_union(value_type):
            print('{} is not a type.'.format(value_type))
        elif not self._type_matches(default_value, value_type):
            print('Type of {} must be {}.'.format(default_value, value_type))

        if option_name in self.__config_dict[group_name]:
            raise Exception('{} has been already added.'.format(option_name))
        if check_func is not None and not check_func(default_value):
            raise Exception('Checking {}/{} failed.'.format(group_name, option_name))
        self.__config_dict[group_name][option_name] = default_value
        self.__check_func_dict[group_name][option_name] = check_func
        self.__type_dict[group_name][option_name] = value_type

    def _merge_config(self, yaml_config):
        """Return a deep copy of the defaults overridden by ``yaml_config``.

        Args:
            yaml_config: Parsed YAML document, expected to map group names to
                mappings of option names to values. Groups and options that
                were never registered are ignored.

        Returns:
            dict: group -> option -> value.

        Raises:
            Exception: if ``yaml_config`` or one of its registered groups is
                not a mapping, or a value fails its ``check_func``.
        """
        if type(yaml_config) is not dict:
            raise Exception('Loading yaml file failed.')

        # Deep copy so callers may mutate the result without touching defaults.
        config_dict = copy.deepcopy(self.__config_dict)
        for group, options in config_dict.items():
            group_config = yaml_config.get(group)
            if group_config is None:
                # Group absent, or written as an empty `group:` section.
                continue
            if not isinstance(group_config, dict):
                raise Exception('Group {} must be a mapping.'.format(group))
            for option in options:
                if option not in group_config:
                    continue
                value = group_config[option]
                value_type = self.__type_dict[group][option]
                if not self._type_matches(value, value_type):
                    # Reported but still applied so existing config files keep loading.
                    print('Type of {}/{} must be {}.'.format(group, option, value_type))
                elif value is not None:
                    # None is only accepted by Optional types; validators such
                    # as is_greater_than_0 cannot be applied to it.
                    check_func = self.__check_func_dict[group][option]
                    if check_func is not None and not check_func(value):
                        raise Exception('Checking {}/{} failed.'.format(group, option))
                options[option] = value
        return config_dict

    def parse_config(self, cfg_file):
        """Load ``cfg_file`` with ``yaml.safe_load`` and merge it over the defaults.

        Args:
            cfg_file: Path to a UTF-8 YAML file mapping groups to options.

        Returns:
            dict: group -> option -> value; the registered defaults are not modified.

        Raises:
            Exception: if the file is not a YAML mapping, a registered group
                is not a mapping, or a value fails its ``check_func``.
        """
        with open(cfg_file, 'r', encoding='utf-8') as f:
            yaml_config = yaml.safe_load(f)
        return self._merge_config(yaml_config)