# Base command-line options shared by the train/test entry points.
import argparse
import importlib
import json
import multiprocessing as mp
import os
import pkgutil
from argparse import Namespace

import torch
import yaml

import models
import training.datasets as data
import training.utils as utils
from training.utils import get_latest_checkpoint_path
class BaseOptions:
    """Argparse-based option container shared by the train/test entry points.

    Builds the base parser, then lets the selected model and dataset
    register their own options, optionally overriding defaults from a
    saved hparams file (``.json`` or ``.yaml``).  Subclasses are expected
    to set ``self.is_train`` (and, presumably, add flags such as
    ``--continue_train`` — not visible here, so it is read defensively).
    """

    def __init__(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                         add_help=False)  # TODO - check that help is still displayed
        parser.add_argument('-d', '--data_dir', type=str, default='data/scaled_features')
        parser.add_argument('--hparams_file', type=str, default=None)
        parser.add_argument('--dataset_name', type=str, default="multimodal")
        parser.add_argument('--base_filenames_file', type=str, default="base_filenames_train.txt")
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--batch_size', default=1, type=int)
        parser.add_argument('--val_batch_size', default=1, type=int, help='batch size for validation data loader')
        parser.add_argument('--do_validation', action='store_true', help='whether to do validation steps during training')
        parser.add_argument('--do_testing', action='store_true', help='whether to do evaluation on test set at the end of training')
        parser.add_argument('--skip_training', action='store_true', help='whether to not do training (only useful when doing just testing)')
        parser.add_argument('--do_tuning', action='store_true', help='whether to not do the tuning phase (e.g. to tune learning rate)')
        parser.add_argument('--model', type=str, default="transformer", help="The network model used for beatsaberification")
        # see here for guidelines on setting number of workers: https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813
        # and here https://pytorch-lightning.readthedocs.io/_/downloads/en/latest/pdf/ (where they recommend to use accelerator=ddp rather than ddp_spawn)
        parser.add_argument('--workers', default=0, type=int, help='the number of workers to load the data')
        parser.add_argument('--experiment_name', default="experiment_name", type=str)
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--fork_processes', action='store_true', help="Set method to create dataloader child processes to fork instead of spawn (could take up more memory)")
        parser.add_argument('--find_unused_parameters', action='store_true', help="option used with DDP which allows having parameters which are not used for producing the loss. Setting it to false is more efficient, if this option is not needeed")
        ### CHECKPOINTING STUFF
        parser.add_argument('--checkpoints_dir', default="training/experiments", type=str, help='checkpoint folder')
        parser.add_argument('--load_weights_only', action='store_true', help='if specified, we load the model weights from the last checkpoint for the specified experiment, WITHOUT loading the optimizer parameters! (allows to continue traning while changing the optimizer)')
        parser.add_argument('--no_load_hparams', action='store_true', help='if specified, we dont load the saved experiment hparams when doing continue_train')
        parser.add_argument('--ignore_in_state_dict', type=str, default="", help="substring to match in state dict, to then ignore the corresponding saved weights. Sometimes useful for models where only some part was trained e.g.")
        parser.add_argument('--only_load_in_state_dict', type=str, default="", help="substring to match in state dict, to then only load the corresponding saved weights. Sometimes useful for models where only some part was trained e.g.")
        self.parser = parser
        # Filled in by subclasses (True for train options, False for test).
        self.is_train = None
        # Hparam keys accepted from an hparams file even though they are not
        # parser defaults.
        self.extra_hparams = ["is_train"]
        self.opt = None

    @staticmethod
    def _load_hparams(hparams_file):
        """Load an hparams dict from a ``.json`` or ``.yaml`` file.

        Raises:
            ValueError: if the file has an unsupported extension (the
                original code left ``hparams_json`` unbound in that case,
                producing a confusing ``NameError`` later).
        """
        if hparams_file.endswith(".json"):
            with open(hparams_file) as f:
                text = f.read()
            try:
                # jsmin strips // comments from "JSON with comments" files.
                # It was used unimported before (NameError); it is optional —
                # plain JSON parses fine without it.
                from jsmin import jsmin
                text = jsmin(text)
            except ImportError:
                pass
            return json.loads(text)
        if hparams_file.endswith(".yaml"):
            with open(hparams_file) as f:
                # FullLoader matches the modern PyYAML default while avoiding
                # the warning emitted by a bare yaml.load(f).
                return yaml.load(f, Loader=yaml.FullLoader)
        raise ValueError("Unsupported hparams file extension: " + hparams_file)

    def gather_options(self, parse_args=None):
        """Parse options in stages and return the final Namespace.

        Stages: base parser -> (optional) hparams-file defaults -> model
        options -> dataset options -> ``--no-<flag>`` negations ->
        hparams-file defaults re-applied over the full parser.

        Args:
            parse_args: optional list of CLI tokens; when None, sys.argv
                is used.
        """
        def _parse_known(p):
            # Honour an explicit argument list if one was given.
            if parse_args is not None:
                return p.parse_known_args(parse_args)[0]
            return p.parse_known_args()[0]

        opt = _parse_known(self.parser)
        defaults = vars(self.parser.parse_args([]))

        hparams_json = None
        # --continue_train is presumably defined by a subclass parser (not
        # visible in this file), so read it defensively instead of crashing
        # with AttributeError when BaseOptions is used directly.
        if getattr(opt, "continue_train", False) and not opt.no_load_hparams:
            logs_path = opt.checkpoints_dir + "/" + opt.experiment_name
            try:
                latest_checkpoint_path = get_latest_checkpoint_path(logs_path)
            except FileNotFoundError:
                print("checkpoint file not found. Probably trying continue_train on an experiment with no checkpoints")
                raise
            hparams_file = latest_checkpoint_path + "/hparams.yaml"
            print("Loading hparams file ", hparams_file)
        else:
            hparams_file = opt.hparams_file

        # BUG FIX: check the *resolved* hparams_file.  The original checked
        # opt.hparams_file, so hparams recovered via continue_train were never
        # loaded and a NameError on hparams_json followed further down.
        if hparams_file is not None:
            hparams_json = self._load_hparams(hparams_file)
            # Only override known keys; False values are dropped so that
            # store_true flags are not force-set as defaults.
            overrides = {k: v for k, v in hparams_json.items() if (v != False and k in defaults)}
            self.parser.set_defaults(**overrides)
            opt = _parse_known(self.parser)

        # Let the chosen model extend the parser with its own options.
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(self.parser, opt)
        # BUG FIX: re-parse with the *extended* parser.  The original fell
        # back to self.parser when parse_args was None, silently dropping the
        # model-specific defaults.
        opt = _parse_known(parser)

        # Let the chosen dataset extend the parser with its own options.
        print(opt.dataset_name)
        dataset_option_setter = data.get_option_setter(opt.dataset_name)
        parser = dataset_option_setter(parser, self.is_train)

        # Add a --no-<flag> negation for every boolean defaulting to False.
        defaults = vars(parser.parse_args([]))
        for key, val in defaults.items():
            if val == False:
                parser.add_argument("--no-" + key, dest=key, action="store_false")

        # Re-apply the hparams-file values over the fully-extended parser,
        # rejecting keys that no stage of the parser recognizes.
        if hparams_json is not None:
            overrides = {}
            for k, v in hparams_json.items():
                if k in defaults or k in self.extra_hparams:
                    if v != False:
                        overrides[k] = v
                else:
                    raise Exception("Hparam " + k + " not recognized!")
            parser.set_defaults(**overrides)

        self.parser = parser
        if parse_args is not None:
            return parser.parse_args(parse_args)
        return parser.parse_args()

    def print_options(self, opt):
        """Pretty-print all options (flagging non-default values) and save
        them to ``<checkpoints_dir>/<experiment_name>/opt.txt`` and
        ``opt.json``.
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)
        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.experiment_name)
        utils.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        file_name_json = os.path.join(expr_dir, 'opt.json')
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')
        with open(file_name_json, 'wt') as opt_file:
            opt_file.write(json.dumps(vars(opt)))

    def parse(self, parse_args=None):
        """Gather options, stamp ``is_train``, print/save them, and return
        the final Namespace (also stored on ``self.opt``).
        """
        opt = self.gather_options(parse_args=parse_args)
        opt.is_train = self.is_train  # train or test
        # Drop any callables that leaked into the namespace so the options
        # remain JSON-serializable in print_options.
        opt = {k: v for (k, v) in vars(opt).items() if not callable(v)}
        opt = Namespace(**opt)
        self.print_options(opt)
        self.opt = opt
        return self.opt