File size: 9,177 Bytes
7e2a2a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import yaml
import copy
from typing import Union

class BaseConfig:
    """Registry of grouped configuration options with defaults and validation.

    Options are registered with `_add_option` under a group name together
    with a declared type, a default value and an optional validation
    callable. `parse_config` overlays the values found in a YAML file on a
    copy of the defaults and returns the resulting nested dict.
    """

    def __init__(self):
        """Register every built-in option with its default value."""
        self.__config_dict = {}      # group -> option -> default value
        self.__check_func_dict = {}  # group -> option -> validator (or None)
        self.__type_dict = {}        # group -> option -> tuple of accepted classes

        def is_greater_than_0(x):
            return x > 0

        # common config
        self._add_option('common', 'name', str, 'style_master')
        self._add_option('common', 'model', str, 'cycle_gan')
        self._add_option('common', 'phase', str, 'train', check_func=lambda x: x in ['train', 'test'])
        self._add_option('common', 'gpu_ids', list, [0])
        self._add_option('common', 'verbose', bool, False)

        # model config
        self._add_option('model', 'input_nc', int, 3, check_func=is_greater_than_0)
        self._add_option('model', 'output_nc', int, 3, check_func=is_greater_than_0)

        # dataset config
        # common dataset options
        self._add_option('dataset', 'use_absolute_datafile', bool, True)
        self._add_option('dataset', 'batch_size', int, 1, check_func=is_greater_than_0)
        self._add_option('dataset', 'n_threads', int, 4, check_func=is_greater_than_0)
        self._add_option('dataset', 'dataroot', str, './')
        self._add_option('dataset', 'drop_last', bool, False)
        # A None default means "unset"; a list may be supplied from the YAML file.
        self._add_option('dataset', 'landmark_scale', list, None)
        self._add_option('dataset', 'check_all_data', bool, False)
        # Upon loading bad data: if this is true, the dataloader will throw an
        # exception and load the next good data. If this is false, the process
        # will crash.
        self._add_option('dataset', 'accept_data_error', bool, True)

        self._add_option('dataset', 'train_data', dict, {})
        self._add_option('dataset', 'val_data', dict, {})

        # paired data config
        self._add_option('dataset', 'paired_trainA_folder', str, '')
        self._add_option('dataset', 'paired_trainB_folder', str, '')
        self._add_option('dataset', 'paired_train_filelist', str, '')
        self._add_option('dataset', 'paired_valA_folder', str, '')
        self._add_option('dataset', 'paired_valB_folder', str, '')
        self._add_option('dataset', 'paired_val_filelist', str, '')

        # unpaired data config
        self._add_option('dataset', 'unpaired_trainA_folder', str, '')
        self._add_option('dataset', 'unpaired_trainB_folder', str, '')
        self._add_option('dataset', 'unpaired_trainA_filelist', str, '')
        self._add_option('dataset', 'unpaired_trainB_filelist', str, '')
        self._add_option('dataset', 'unpaired_valA_folder', str, '')
        self._add_option('dataset', 'unpaired_valB_folder', str, '')
        self._add_option('dataset', 'unpaired_valA_filelist', str, '')
        self._add_option('dataset', 'unpaired_valB_filelist', str, '')

        # custom data
        self._add_option('dataset', 'custom_train_data', dict, {})
        self._add_option('dataset', 'custom_val_data', dict, {})

        # training config
        self._add_option('training', 'checkpoints_dir', str, './checkpoints')
        self._add_option('training', 'log_dir', str, './logs')
        self._add_option('training', 'use_new_log', bool, False)
        self._add_option('training', 'continue_train', bool, False)
        self._add_option('training', 'which_epoch', str, 'latest')
        self._add_option('training', 'n_epochs', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'n_epochs_decay', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'save_latest_freq', int, 5000, check_func=is_greater_than_0)
        self._add_option('training', 'print_freq', int, 200, check_func=is_greater_than_0)
        self._add_option('training', 'save_epoch_freq', int, 5, check_func=is_greater_than_0)
        self._add_option('training', 'epoch_as_iter', bool, False)
        self._add_option('training', 'lr', float, 2e-4, check_func=is_greater_than_0)
        self._add_option('training', 'lr_policy', str, 'linear',
            check_func=lambda x: x in ['linear', 'step', 'plateau', 'cosine'])
        self._add_option('training', 'lr_decay_iters', int, 50, check_func=is_greater_than_0)
        self._add_option('training', 'DDP', bool, False)
        self._add_option('training', 'num_nodes', int, 1, check_func=is_greater_than_0)
        self._add_option('training', 'DDP_address', str, '127.0.0.1')
        self._add_option('training', 'DDP_port', str, '29700')
        self._add_option('training', 'find_unused_parameters', bool, False)  # a DDP option that allows backward on a subgraph of the model
        self._add_option('training', 'val_percent', float, 5.0, check_func=is_greater_than_0)  # Uses x% of training data to validate
        self._add_option('training', 'val', bool, True)  # perform validation every epoch
        self._add_option('training', 'save_training_progress', bool, False)  # save images to create a training progression video

        # testing config
        self._add_option('testing', 'results_dir', str, './results')
        self._add_option('testing', 'load_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'crop_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'preprocess', list, ['scale_width'])
        self._add_option('testing', 'visual_names', list, [])
        self._add_option('testing', 'num_test', int, 999999, check_func=is_greater_than_0)
        self._add_option('testing', 'image_format', str, 'jpg', check_func=lambda x: x in ['input', 'jpg', 'jpeg', 'png'])

    @staticmethod
    def _accepted_types(value_type):
        """Return the concrete classes named by `value_type`.

        Args:
            value_type: a class, or a `typing.Union[...]` of classes.

        Returns:
            A tuple of classes, or None when `value_type` is neither a class
            nor a `typing.Union`.
        """
        if isinstance(value_type, type):
            return (value_type,)
        if getattr(value_type, '__origin__', None) is Union:
            members = []
            for member in value_type.__args__:
                if not isinstance(member, type):
                    # Parameterised generics such as List[int] expose their
                    # runtime class through __origin__.
                    member = getattr(member, '__origin__', None)
                if isinstance(member, type):
                    members.append(member)
            return tuple(members)
        return None

    def _add_option(self, group_name, option_name, value_type, default_value, check_func=None):
        """Register one option.

        Args:
            group_name: name of the group the option belongs to (str).
            option_name: name of the option inside the group (str).
            value_type: a class or `typing.Union[...]` describing the values
                the option accepts.
            default_value: value used when the config file does not override
                the option. None is accepted for any type and means "unset".
            check_func: optional callable; it receives a candidate value and
                must return a truthy result for the value to be accepted.

        Raises:
            Exception: if a name is not a str, if the option was already
                registered in this group, or if `check_func` rejects
                `default_value`.

        Type problems in the declaration (an unusable `value_type`, or a
        default that does not match it) are reported on stdout but are not
        fatal, so existing subclasses keep working.
        """
        # check name type
        if not isinstance(group_name, str) or not isinstance(option_name, str):
            raise Exception('Type of {} and {} must be str.'.format(group_name, option_name))

        # add group
        group_options = self.__config_dict.setdefault(group_name, {})
        group_checks = self.__check_func_dict.setdefault(group_name, {})
        group_types = self.__type_dict.setdefault(group_name, {})

        # check type & default value
        accepted = self._accepted_types(value_type)
        if accepted is None:
            print('{} is not a type.'.format(value_type))
            accepted = ()
        elif default_value is not None and type(default_value) not in accepted:
            print('Type of {} must be {}.'.format(default_value, value_type))
        # The default's own type always stays acceptable, so a value that was
        # valid before (same type as the default) is never rejected.
        if type(default_value) not in accepted:
            accepted = accepted + (type(default_value),)

        # add option to dict
        if option_name in group_options:
            raise Exception('{} has been already added.'.format(option_name))
        if check_func is not None and not check_func(default_value):
            raise Exception('Checking {}/{} failed.'.format(group_name, option_name))
        group_options[option_name] = default_value
        group_checks[option_name] = check_func
        group_types[option_name] = accepted

    def _merge_overrides(self, overrides):
        """Overlay `overrides` on a deep copy of the registered defaults.

        Args:
            overrides: dict mapping group name -> dict of option -> value.
                Unknown groups and options are ignored. A group whose value
                is None (an empty YAML section) is treated as empty.

        Returns:
            A new nested dict group -> option -> value.

        Raises:
            Exception: if an option's validator rejects the supplied value.

        A value whose type is not one of the option's accepted types is
        reported on stdout and the default is kept.
        """
        merged = copy.deepcopy(self.__config_dict)
        for group, options in merged.items():
            section = overrides.get(group)
            if section is None:
                continue
            if not isinstance(section, dict):
                print('Section {} must be a mapping of options.'.format(group))
                continue
            for option in options:
                if option not in section:
                    continue
                value = section[option]
                accepted = self.__type_dict[group][option]
                # Exact type match on purpose: a bool must not pass as an int.
                if type(value) not in accepted:
                    print('Type of {}/{} must be {}.'.format(
                        group, option, ' or '.join(t.__name__ for t in accepted)))
                    continue
                check_func = self.__check_func_dict[group][option]
                if check_func is not None and not check_func(value):
                    raise Exception('Checking {}/{} failed.'.format(group, option))
                options[option] = value
        return merged

    def parse_config(self, cfg_file):
        """Load a YAML file and return the defaults overridden by its values.

        Args:
            cfg_file: path of the YAML configuration file.

        Returns:
            A new nested dict group -> option -> value; the registered
            defaults themselves are not modified.

        Raises:
            Exception: if the file's top level is not a mapping, or if a
                validator rejects a supplied value.
        """
        # load config from yaml file
        with open(cfg_file, 'r') as f:
            yaml_config = yaml.safe_load(f)
        if not isinstance(yaml_config, dict):
            raise Exception('Loading yaml file failed.')

        # replace default options
        return self._merge_overrides(yaml_config)