# DECO/utils/config.py

import itertools
import operator
import os
import shutil
import time
from functools import reduce
from typing import List, Union
import configargparse
import yaml
from flatten_dict import flatten, unflatten
from loguru import logger
from yacs.config import CfgNode as CN
from utils.cluster import execute_task_on_cluster
from utils.default_hparams import hparams


def parse_args():
    def add_common_cmdline_args(parser):
        # command-line options, mostly used when spawning jobs on the cluster
        parser.add_argument('--cfg', required=True, type=str, help='cfg file path')
        parser.add_argument('--opts', default=[], nargs='*', help='additional options to update config')
        parser.add_argument('--cfg_id', type=int, default=0, help='cfg id to run when multiple experiments are spawned')
        parser.add_argument('--cluster', default=False, action='store_true', help='creates submission files for cluster')
        parser.add_argument('--bid', type=int, default=10, help='amount of bid for cluster')
        parser.add_argument('--memory', type=int, default=64000, help='memory amount for cluster')
        parser.add_argument('--gpu_min_mem', type=int, default=12000, help='minimum amount of GPU memory')
        parser.add_argument('--gpu_arch', default=['tesla', 'quadro', 'rtx'],
                            nargs='*', help='allowed GPU architectures for cluster')
        parser.add_argument('--num_cpus', type=int, default=8, help='num cpus for cluster')
        return parser

    # main command-line parser
    arg_formatter = configargparse.ArgumentDefaultsHelpFormatter
    cfg_parser = configargparse.YAMLConfigFileParser
    description = 'PyTorch implementation of DECO'
    parser = configargparse.ArgumentParser(formatter_class=arg_formatter,
                                           config_file_parser_class=cfg_parser,
                                           description=description,
                                           prog='deco')
    parser = add_common_cmdline_args(parser)

    args = parser.parse_args()
    print(args, end='\n\n')

    return args
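
# The parser above is normally driven from the command line; a hypothetical
# invocation (paths and override values are only illustrative, not taken from
# this repo) would look like:
#
#   python train.py --cfg configs/train.yaml --cfg_id 0 --opts TRAINING.BATCH_SIZE 32
#
# With --cluster, run_grid_search_experiments below only writes Condor
# submission files via execute_task_on_cluster instead of running locally.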


def get_hparams_defaults():
    """Get a yacs CfgNode object with the default hyperparameters for DECO."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return hparams.clone()


def update_hparams(hparams_file):
    hparams = get_hparams_defaults()
    hparams.merge_from_file(hparams_file)
    return hparams.clone()
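
# A minimal usage sketch (the YAML path is hypothetical):
#
#   hparams = update_hparams('configs/train.yaml')
#   print(hparams.EXP_NAME)   # values from the YAML override the defaults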


def update_hparams_from_dict(cfg_dict):
    hparams = get_hparams_defaults()
    cfg = hparams.load_cfg(str(cfg_dict))
    hparams.merge_from_other_cfg(cfg)
    return hparams.clone()
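
# Sketch of the expected input, assuming the dict uses the same key layout as
# the defaults (the value below is hypothetical):
#
#   cfg = update_hparams_from_dict({'EXP_NAME': 'debug_run'})
#   # -> a full CfgNode in which only EXP_NAME differs from the defaults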


def get_grid_search_configs(config, excluded_keys=[]):
    """
    :param config: dictionary with the configurations
    :param excluded_keys: keys whose list values should not be expanded in the grid search
    :return: the list of expanded configurations and the names of the tuned hyperparameters
    """
    def bool_to_string(x: Union[List[bool], bool]) -> Union[List[str], str]:
        """
        boolean to string conversion
        :param x: list or bool to be converted
        :return: the converted string (or list of strings)
        """
        if isinstance(x, bool):
            return [str(x)]
        for i, j in enumerate(x):
            x[i] = str(j)
        return x

    # flatten the nested config; keys excluded from the grid search are kept as single values
    flattened_config_dict = flatten(config, reducer='path')
    hyper_params = []

    for k, v in flattened_config_dict.items():
        if isinstance(v, list):
            if k in excluded_keys:
                flattened_config_dict[k] = ['+'.join(v)]
            elif len(v) > 1:
                hyper_params += [k]

        if isinstance(v, list) and isinstance(v[0], bool):
            flattened_config_dict[k] = bool_to_string(v)

        if not isinstance(v, list):
            if isinstance(v, bool):
                flattened_config_dict[k] = bool_to_string(v)
            else:
                flattened_config_dict[k] = [v]

    # cartesian product over all list-valued keys gives one dict per experiment
    keys, values = zip(*flattened_config_dict.items())
    experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]

    for exp_id, exp in enumerate(experiments):
        for param in excluded_keys:
            exp[param] = exp[param].strip().split('+')
        for param_name, param_value in exp.items():
            # convert stringified booleans back to bool
            if isinstance(param_value, list) and (param_value[0] in ['True', 'False']):
                exp[param_name] = [x == 'True' for x in param_value]
            if param_value in ['True', 'False']:
                exp[param_name] = (param_value == 'True')

        experiments[exp_id] = unflatten(exp, splitter='path')

    return experiments, hyper_params
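
# A worked sketch of the expansion (keys and values are hypothetical):
#
#   cfg = {'OPTIM': {'LR': [0.001, 0.0001]}, 'TRAINING': {'DATASETS': ['a', 'b']}}
#   configs, tuned = get_grid_search_configs(cfg, excluded_keys=['TRAINING/DATASETS'])
#   # -> 2 configs (one per LR value) and tuned == ['OPTIM/LR'];
#   #    TRAINING/DATASETS stays a single list in every config because it is excluded.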


def get_from_dict(d, keys):
    # walk a nested dict following the given sequence of keys
    return reduce(operator.getitem, keys, d)
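
# For instance (hypothetical dictionary):
#
#   get_from_dict({'TRAINING': {'LR': 0.001}}, ['TRAINING', 'LR'])  # -> 0.001
#
# run_grid_search_experiments uses this to read a tuned value out of a nested
# config given its 'A/B/C' path split on '/'.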


def save_dict_to_yaml(obj, filename, mode='w'):
    with open(filename, mode) as f:
        yaml.dump(obj, f, default_flow_style=False)
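
# E.g. save_dict_to_yaml({'EXP_NAME': 'debug_run'}, '/tmp/config.yaml') writes a
# block-style YAML file (the path and value here are only illustrative).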


def run_grid_search_experiments(
        args,
        script='train.py',
        change_wt_name=True
):
    with open(args.cfg) as f:
        cfg = yaml.safe_load(f)

    # split the config file into a list of configs, one per combination of the
    # tuned hyperparameters, and also return the names of those hyperparameters
    different_configs, hyperparams = get_grid_search_configs(
        cfg,
        excluded_keys=['TRAINING/DATASETS', 'TRAINING/DATASET_MIX_PDF', 'VALIDATION/DATASETS'],
    )
    logger.info(f'Grid search hparams: \n {hyperparams}')

    # The config file may be missing some default values, so we need to add them
    different_configs = [update_hparams_from_dict(c) for c in different_configs]
    logger.info(f'======> Number of experiment configurations is {len(different_configs)}')

    config_to_run = CN(different_configs[args.cfg_id])

    if args.cluster:
        execute_task_on_cluster(
            script=script,
            exp_name=config_to_run.EXP_NAME,
            output_dir=config_to_run.OUTPUT_DIR,
            condor_dir=config_to_run.CONDOR_DIR,
            cfg_file=args.cfg,
            num_exp=len(different_configs),
            bid_amount=args.bid,
            num_workers=config_to_run.DATASET.NUM_WORKERS,
            memory=args.memory,
            exp_opts=args.opts,
            gpu_min_mem=args.gpu_min_mem,
            gpu_arch=args.gpu_arch,
        )
        exit()

    # ==== create logdir using hyperparam settings
    logtime = time.strftime('%d-%m-%Y_%H-%M-%S')
    logdir = f'{logtime}_{config_to_run.EXP_NAME}'
    wt_file = config_to_run.EXP_NAME + '_'
    for hp in hyperparams:
        v = get_from_dict(different_configs[args.cfg_id], hp.split('/'))
        logdir += f'_{hp.replace("/", ".").replace("_", "").lower()}-{v}'
        wt_file += f'{hp.replace("/", ".").replace("_", "").lower()}-{v}_'

    logdir = os.path.join(config_to_run.OUTPUT_DIR, logdir)
    os.makedirs(logdir, exist_ok=True)
    config_to_run.LOGDIR = logdir

    wt_file += 'best.pth'
    wt_path = os.path.join(os.path.dirname(config_to_run.TRAINING.BEST_MODEL_PATH), wt_file)
    if change_wt_name:
        config_to_run.TRAINING.BEST_MODEL_PATH = wt_path

    shutil.copy(src=args.cfg, dst=os.path.join(logdir, 'config.yaml'))

    # save the fully-resolved config next to the logs
    save_dict_to_yaml(
        unflatten(flatten(config_to_run)),
        os.path.join(config_to_run.LOGDIR, 'config_to_run.yaml')
    )

    return config_to_run
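
# A minimal end-to-end sketch of how a training script would use this module
# (the entry-point code below is illustrative, not part of this file):
#
#   from utils.config import parse_args, run_grid_search_experiments
#
#   if __name__ == '__main__':
#       args = parse_args()
#       hparams = run_grid_search_experiments(args, script='train.py')
#       # hparams is the fully-merged CfgNode for experiment args.cfg_id,
#       # with LOGDIR created and the config snapshot written inside it.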