import os
import argparse
import multiprocessing as mp
import torch
import importlib
import pkgutil
import models
import training.datasets as data
import json, yaml
import training.utils as utils
from argparse import Namespace
from training.utils import get_latest_checkpoint_path


class BaseOptions:
    """Collects and parses the command-line options shared by all experiments.

    Parsing happens in stages (see ``gather_options``):
      1. the base options defined in ``__init__`` are parsed;
      2. defaults are optionally overridden from a saved hparams file —
         either one given via ``--hparams_file`` or, when continuing
         training, the ``hparams.yaml`` stored next to the latest checkpoint;
      3. the model- and dataset-specific option setters extend the parser,
         and ``--no-<flag>`` negation options are generated for every boolean
         flag so the command line can override a True from the hparams file.
    """

    def __init__(self):
        """Build the base argument parser; model/dataset options are added later."""
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            add_help=False)  # TODO - check that help is still displayed
        parser.add_argument('-d', '--data_dir', type=str, default='data/scaled_features')
        parser.add_argument('--hparams_file', type=str, default=None)
        parser.add_argument('--dataset_name', type=str, default="multimodal")
        parser.add_argument('--base_filenames_file', type=str, default="base_filenames_train.txt")
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--batch_size', default=1, type=int)
        parser.add_argument('--val_batch_size', default=1, type=int, help='batch size for validation data loader')
        parser.add_argument('--do_validation', action='store_true', help='whether to do validation steps during training')
        parser.add_argument('--do_testing', action='store_true', help='whether to do evaluation on test set at the end of training')
        parser.add_argument('--skip_training', action='store_true', help='whether to not do training (only useful when doing just testing)')
        parser.add_argument('--do_tuning', action='store_true', help='whether to not do the tuning phase (e.g. to tune learning rate)')
        parser.add_argument('--model', type=str, default="transformer", help="The network model used for beatsaberification")
        # For guidelines on setting the number of workers see
        # https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813
        # and https://pytorch-lightning.readthedocs.io/_/downloads/en/latest/pdf/
        # (which recommends accelerator=ddp rather than ddp_spawn).
        parser.add_argument('--workers', default=0, type=int, help='the number of workers to load the data')
        parser.add_argument('--experiment_name', default="experiment_name", type=str)
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--fork_processes', action='store_true', help="Set method to create dataloader child processes to fork instead of spawn (could take up more memory)")
        parser.add_argument('--find_unused_parameters', action='store_true', help="option used with DDP which allows having parameters which are not used for producing the loss. Setting it to false is more efficient, if this option is not needeed")
        ### CHECKPOINTING STUFF
        parser.add_argument('--checkpoints_dir', default="training/experiments", type=str, help='checkpoint folder')
        parser.add_argument('--load_weights_only', action='store_true', help='if specified, we load the model weights from the last checkpoint for the specified experiment, WITHOUT loading the optimizer parameters! (allows to continue traning while changing the optimizer)')
        parser.add_argument('--no_load_hparams', action='store_true', help='if specified, we dont load the saved experiment hparams when doing continue_train')
        parser.add_argument('--ignore_in_state_dict', type=str, default="", help="substring to match in state dict, to then ignore the corresponding saved weights. Sometimes useful for models where only some part was trained e.g.")
        parser.add_argument('--only_load_in_state_dict', type=str, default="", help="substring to match in state dict, to then only load the corresponding saved weights. Sometimes useful for models where only some part was trained e.g.")
        self.parser = parser
        self.is_train = None                # set by subclasses (train vs. test options)
        self.extra_hparams = ["is_train"]   # hparam keys accepted even without a CLI flag
        self.opt = None                     # cached result of parse()

    def _parse_known(self, parser, parse_args):
        """Parse known args from an explicit list, or from sys.argv when None."""
        if parse_args is not None:
            opt, _ = parser.parse_known_args(parse_args)
        else:
            opt, _ = parser.parse_known_args()
        return opt

    def _load_hparams(self, hparams_file):
        """Load a hyper-parameter dict from a .json or .yaml file.

        Raises ValueError for an unrecognized extension (previously this path
        fell through silently and caused a NameError later on).
        """
        if hparams_file.endswith(".json"):
            # BUG FIX: the original wrapped this in jsmin(), a name that was
            # never imported, so .json hparams files always raised NameError.
            # Plain json.load works without the third-party dependency
            # (comments in json hparams files are no longer stripped).
            with open(hparams_file) as f:
                return json.load(f)
        elif hparams_file.endswith(".yaml"):
            # safe_load: avoids the deprecated Loader-less yaml.load call and
            # never executes arbitrary Python from the hparams file.
            with open(hparams_file) as f:
                return yaml.safe_load(f)
        raise ValueError("Unrecognized hparams file extension: " + hparams_file)

    def gather_options(self, parse_args=None):
        """Run the staged parse and return the fully-parsed Namespace.

        Parameters:
            parse_args -- optional list of argument strings; when None,
                sys.argv is parsed instead.
        """
        # get the basic options
        opt = self._parse_known(self.parser, parse_args)
        defaults = vars(self.parser.parse_args([]))

        hparams_json = None
        # --continue_train is expected to be registered by a model/task option
        # setter, so it may not exist yet at this point; treat absent as False.
        if getattr(opt, 'continue_train', False) and not opt.no_load_hparams:
            logs_path = opt.checkpoints_dir + "/" + opt.experiment_name
            try:
                latest_checkpoint_path = get_latest_checkpoint_path(logs_path)
            except FileNotFoundError:
                print("checkpoint file not found. Probably trying continue_train on an experiment with no checkpoints")
                raise
            hparams_file = latest_checkpoint_path + "/hparams.yaml"
            print("Loading hparams file ", hparams_file)
        else:
            hparams_file = opt.hparams_file

        # BUG FIX: this used to test opt.hparams_file, so the checkpoint
        # hparams.yaml selected above was never actually loaded when doing
        # continue_train without an explicit --hparams_file (and the later
        # hparams_json use raised NameError).
        if hparams_file is not None:
            hparams_json = self._load_hparams(hparams_file)
            # False-valued entries are dropped (note 0 == False, so 0 is
            # dropped too — original behavior kept) so store_true flags keep
            # working; unknown keys are ignored here and validated once all
            # option setters have extended the parser.
            hparams_json2 = {k: v for k, v in hparams_json.items()
                             if (v != False and k in defaults)}
            self.parser.set_defaults(**hparams_json2)
            opt = self._parse_known(self.parser, parse_args)

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(self.parser, opt)
        # parse again with the new defaults (BUG FIX: the sys.argv branch used
        # to re-parse with the stale self.parser instead of the extended one)
        opt = self._parse_known(parser, parse_args)

        # modify dataset-related parser options
        dataset_name = opt.dataset_name
        print(dataset_name)
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.is_train)

        # add negation flags (--no-<flag>) for every boolean option so a
        # False on the command line can override a True from the hparams file
        defaults = vars(parser.parse_args([]))
        for key, val in defaults.items():
            if val == False:
                parser.add_argument("--no-" + key, dest=key, action="store_false")

        # now that all options are known, re-apply the hparams and reject any
        # key that no option setter recognizes
        if hparams_file is not None:
            hparams_json2 = {}
            for k, v in hparams_json.items():
                if k in defaults or k in self.extra_hparams:
                    if v != False:
                        hparams_json2[k] = v
                else:
                    raise Exception("Hparam " + k + " not recognized!")
            parser.set_defaults(**hparams_json2)

        self.parser = parser
        if parse_args is not None:
            return parser.parse_args(parse_args)
        else:
            return parser.parse_args()

    def print_options(self, opt):
        """Pretty-print all options (flagging non-default values) and save
        them to <checkpoints_dir>/<experiment_name>/opt.txt and opt.json.
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.experiment_name)
        utils.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        file_name_json = os.path.join(expr_dir, 'opt.json')
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')
        with open(file_name_json, 'wt') as opt_file:
            opt_file.write(json.dumps(vars(opt)))

    def parse(self, parse_args=None):
        """Parse the options, attach is_train, print/save them, and cache
        the result on self.opt.

        Callable entries are dropped from the namespace (option setters can
        inject functions) so that it stays json-serializable for opt.json.
        """
        opt = self.gather_options(parse_args=parse_args)
        opt.is_train = self.is_train   # train or test
        opt = Namespace(**{k: v for (k, v) in vars(opt).items() if not callable(v)})
        self.print_options(opt)
        self.opt = opt
        return self.opt