import os
import argparse
import multiprocessing as mp
import torch
import importlib
import pkgutil
import models
import training.datasets as data
import json, yaml
from jsmin import jsmin  # used below to strip comments from .json hparams files
import training.utils as utils
from argparse import Namespace
from training.utils import get_latest_checkpoint_path

class BaseOptions:
    def __init__(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                         add_help=False)  # TODO - check that help is still displayed
        # parser.add_argument('--task', type=str, default='training', help="Module from which dataset and model are loaded")
        parser.add_argument('-d', '--data_dir', type=str, default='data/scaled_features')
        parser.add_argument('--hparams_file', type=str, default=None)
        parser.add_argument('--dataset_name', type=str, default="multimodal")
        parser.add_argument('--base_filenames_file', type=str, default="base_filenames_train.txt")
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--batch_size', default=1, type=int)
        parser.add_argument('--val_batch_size', default=1, type=int, help='batch size for validation data loader')
        parser.add_argument('--do_validation', action='store_true', help='whether to do validation steps during training')
        parser.add_argument('--do_testing', action='store_true', help='whether to do evaluation on test set at the end of training')
        parser.add_argument('--skip_training', action='store_true', help='whether to not do training (only useful when doing just testing)')
        parser.add_argument('--do_tuning', action='store_true', help='whether to do the tuning phase (e.g. to tune the learning rate)')
        # parser.add_argument('--augment', type=int, default=0)
        parser.add_argument('--model', type=str, default="transformer", help="The network model used for beatsaberification")
        # parser.add_argument('--init_type', type=str, default="normal")
        # parser.add_argument('--eval', action='store_true', help='use eval mode during validation / test time.')
        parser.add_argument('--workers', default=0, type=int, help='the number of workers to load the data')
        # see here for guidelines on setting number of workers: https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813
        # and here https://pytorch-lightning.readthedocs.io/_/downloads/en/latest/pdf/ (where they recommend to use accelerator=ddp rather than ddp_spawn)
        parser.add_argument('--experiment_name', default="experiment_name", type=str)
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--fork_processes', action='store_true', help="Set method to create dataloader child processes to fork instead of spawn (could take up more memory)")
        parser.add_argument('--find_unused_parameters', action='store_true', help="option used with DDP which allows having parameters that are not used for producing the loss. Leaving it unset is more efficient when this option is not needed")
        ### CHECKPOINTING STUFF
        parser.add_argument('--checkpoints_dir', default="training/experiments", type=str, help='checkpoint folder')
        parser.add_argument('--load_weights_only', action='store_true', help='if specified, we load the model weights from the last checkpoint for the specified experiment, WITHOUT loading the optimizer parameters (allows continuing training while changing the optimizer)')
        parser.add_argument('--no_load_hparams', action='store_true', help='if specified, we do not load the saved experiment hparams when doing continue_train')
        parser.add_argument('--ignore_in_state_dict', type=str, default="", help="substring to match in the state dict; the corresponding saved weights are ignored. Sometimes useful, e.g. for models where only some part was trained")
        parser.add_argument('--only_load_in_state_dict', type=str, default="", help="substring to match in the state dict; only the corresponding saved weights are loaded. Sometimes useful, e.g. for models where only some part was trained")
        # parser.add_argument('--override_optimizers', action='store_true', help='if specified, we will use the optimizer parameters set by the hparams, even if we are continuing from checkpoint')
        # maybe could override optimizer using this? https://github.com/PyTorchLightning/pytorch-lightning/issues/3095 but need to know the epoch at which to change it

        self.parser = parser
        self.is_train = None
        self.extra_hparams = ["is_train"]
        self.opt = None

    def gather_options(self, parse_args=None):
        # get the basic options
        if parse_args is not None:
            opt, _ = self.parser.parse_known_args(parse_args)
        else:
            opt, _ = self.parser.parse_known_args()

        defaults = vars(self.parser.parse_args([]))

        if opt.continue_train and not opt.no_load_hparams:
            logs_path = os.path.join(opt.checkpoints_dir, opt.experiment_name)
            try:
                latest_checkpoint_path = get_latest_checkpoint_path(logs_path)
            except FileNotFoundError:
                print("Checkpoint not found; probably trying continue_train on an experiment with no checkpoints")
                raise
            hparams_file = os.path.join(latest_checkpoint_path, "hparams.yaml")
            print("Loading hparams file", hparams_file)
        else:
            hparams_file = opt.hparams_file

        # use the resolved hparams_file (which may come from the checkpoint dir above),
        # not opt.hparams_file, which is None when continuing from a checkpoint
        if hparams_file is not None:
            if hparams_file.endswith(".json"):
                hparams_json = json.loads(jsmin(open(hparams_file).read()))
            elif hparams_file.endswith(".yaml"):
                hparams_json = yaml.load(open(hparams_file), Loader=yaml.FullLoader)
            else:
                raise ValueError("hparams file must be .json or .yaml: " + hparams_file)
            # 'is not False' (rather than '!= False') so integer values like 0 are kept
            hparams_json2 = {k: v for k, v in hparams_json.items() if (v is not False and k in defaults)}
            self.parser.set_defaults(**hparams_json2)
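
        # For illustration (hypothetical contents): an hparams.yaml such as
        #     batch_size: 32
        #     experiment_name: my_run
        # becomes the new parser defaults here, so flags passed explicitly on the
        # command line still take precedence over the file.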

        if parse_args is not None:
            opt, _ = self.parser.parse_known_args(parse_args)
        else:
            opt, _ = self.parser.parse_known_args()

        # load task module and task-specific options
        # task_name = opt.task
        # task_options = importlib.import_module("{}.options.task_options".format(task_name))  # must be defined in each task folder
        # self.parser = argparse.ArgumentParser(parents=[self.parser, task_options.TaskOptions().parser])
        # if parse_args is not None:
        #     opt, _ = self.parser.parse_known_args(parse_args)
        # else:
        #     opt, _ = self.parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(self.parser, opt)
        if parse_args is not None:
            opt, _ = parser.parse_known_args(parse_args)  # parse again with the new defaults
        else:
            opt, _ = parser.parse_known_args()  # use the parser returned by the model option setter

        # modify dataset-related parser options
        dataset_name = opt.dataset_name
        print("Dataset:", dataset_name)
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.is_train)

        # add negation flags for boolean options that default to False
        defaults = vars(parser.parse_args([]))
        for key, val in defaults.items():
            if val is False:  # 'is' rather than '==' so integer defaults like 0 are skipped
                parser.add_argument("--no-" + key, dest=key, action="store_false")

        # re-apply the hparams on the fully assembled parser, rejecting any
        # hparam that no command-line option exists for
        if hparams_file is not None:
            hparams_json2 = {}
            for k, v in hparams_json.items():
                if k in defaults or k in self.extra_hparams:
                    if v is not False:
                        hparams_json2[k] = v
                else:
                    raise Exception("Hparam " + k + " not recognized!")
            parser.set_defaults(**hparams_json2)

        self.parser = parser
        if parse_args is not None:
            return parser.parse_args(parse_args)
        else:
            return parser.parse_args()

    def print_options(self, opt):
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.experiment_name)
        utils.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        file_name_json = os.path.join(expr_dir, 'opt.json')
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')
        with open(file_name_json, 'wt') as opt_file:
            opt_file.write(json.dumps(vars(opt)))
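
        # Sample of the saved opt.txt (values illustrative):
        #   ----------------- Options ---------------
        #                  batch_size: 8          [default: 1]
        #             experiment_name: demo
        #   ----------------- End -------------------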

    def parse(self, parse_args=None):

        opt = self.gather_options(parse_args=parse_args)
        opt.is_train = self.is_train   # train or test

        # check options:
        # if opt.loss_weight:
        #     opt.loss_weight = [float(w) for w in opt.loss_weight.split(',')]
        #     if len(opt.loss_weight) != opt.num_class:
        #         raise ValueError("Given {} weights, when {} classes are expected".format(
        #             len(opt.loss_weight), opt.num_class))
        #     else:
        #         opt.loss_weight = torch.tensor(opt.loss_weight)

        # drop callable entries so opt can be serialized (print_options dumps it as JSON)
        opt = {k: v for (k, v) in vars(opt).items() if not callable(v)}
        opt = Namespace(**opt)

        self.print_options(opt)
        # set gpu ids
        # str_ids = opt.gpu_ids.split(',')
        # opt.gpu_ids = []
        # for str_id in str_ids:
        #     id = int(str_id)
        #     if id >= 0:
        #         opt.gpu_ids.append(id)
        # if len(opt.gpu_ids) > 0:
        #     torch.cuda.set_device(opt.gpu_ids[0])
        #
        # set multiprocessing
        #if opt.workers > 0 and not opt.fork_processes:
        #    mp.set_start_method('spawn', force=True)
        #mp.set_start_method('spawn', force=True)

        self.opt = opt
        return self.opt
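

if __name__ == "__main__":
    # Usage sketch (illustrative, with assumed names): TrainOptions here is a
    # hypothetical stand-in for the real subclass defined elsewhere in the repo;
    # it supplies the --continue_train flag that gather_options reads, and sets
    # is_train, which parse() copies onto the options namespace.
    class TrainOptions(BaseOptions):
        def __init__(self):
            super().__init__()
            self.parser.add_argument('--continue_train', action='store_true',
                                     help='resume training from the latest checkpoint')
            self.is_train = True

    # parse() accepts an explicit argv-style list, which is handy for testing;
    # this assumes the default "transformer" model and "multimodal" dataset
    # modules are importable from the repo root
    opt = TrainOptions().parse(parse_args=['--experiment_name', 'demo'])
    print(opt.is_train, opt.continue_train)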