# DECO/utils/config.py
import itertools
import operator
import os
import shutil
import time
from functools import reduce
from typing import List, Union

import configargparse
import yaml
from flatten_dict import flatten, unflatten
from loguru import logger
from yacs.config import CfgNode as CN

from utils.cluster import execute_task_on_cluster
from utils.default_hparams import hparams
def parse_args():
    def add_common_cmdline_args(parser):
        # for cluster runs
        parser.add_argument('--cfg', required=True, type=str, help='cfg file path')
        parser.add_argument('--opts', default=[], nargs='*', help='additional options to update config')
        parser.add_argument('--cfg_id', type=int, default=0, help='cfg id to run when multiple experiments are spawned')
        parser.add_argument('--cluster', default=False, action='store_true', help='creates submission files for cluster')
        parser.add_argument('--bid', type=int, default=10, help='amount of bid for cluster')
        parser.add_argument('--memory', type=int, default=64000, help='memory amount for cluster')
        parser.add_argument('--gpu_min_mem', type=int, default=12000, help='minimum amount of GPU memory')
        parser.add_argument('--gpu_arch', default=['tesla', 'quadro', 'rtx'],
                            nargs='*', help='allowed GPU architectures for cluster jobs')
        parser.add_argument('--num_cpus', type=int, default=8, help='num cpus for cluster')
        return parser

    # main parser
    arg_formatter = configargparse.ArgumentDefaultsHelpFormatter
    cfg_parser = configargparse.YAMLConfigFileParser
    description = 'PyTorch implementation of DECO'
    parser = configargparse.ArgumentParser(formatter_class=arg_formatter,
                                           config_file_parser_class=cfg_parser,
                                           description=description,
                                           prog='deco')
    parser = add_common_cmdline_args(parser)

    args = parser.parse_args()
    print(args, end='\n\n')

    return args
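# Typical invocations, assuming the training entry point (e.g. train.py) calls
# parse_args(); the config path below is illustrative only:
#   python train.py --cfg configs/example.yaml --cfg_id 0
#   python train.py --cfg configs/example.yaml --cluster --bid 30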
def get_hparams_defaults():
    """Get a yacs CfgNode object with default values for DECO."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return hparams.clone()
def update_hparams(hparams_file):
    """Merge a YAML config file on top of the default hparams."""
    hparams = get_hparams_defaults()
    hparams.merge_from_file(hparams_file)
    return hparams.clone()
def update_hparams_from_dict(cfg_dict):
    """Merge a (nested) config dict on top of the default hparams."""
    hparams = get_hparams_defaults()
    cfg = hparams.load_cfg(str(cfg_dict))
    hparams.merge_from_other_cfg(cfg)
    return hparams.clone()
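# Illustrative usage of the two merge helpers above (the 'LR' key is
# hypothetical; any merged key must already exist in
# utils.default_hparams.hparams for yacs to accept it):
#   cfg = update_hparams('configs/example.yaml')                  # YAML file over defaults
#   cfg = update_hparams_from_dict({'TRAINING': {'LR': 1e-4}})    # in-memory dict over defaults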
def get_grid_search_configs(config, excluded_keys=[]):
    """
    :param config: dictionary with the configurations
    :param excluded_keys: keys whose list values should not be expanded in the grid search
    :return: the different configurations and the names of the tuned hyperparameters
    """
    def bool_to_string(x: Union[List[bool], bool]) -> Union[List[str], str]:
        """
        Boolean to string conversion
        :param x: list or bool to be converted
        :return: the converted value(s) as string(s)
        """
        if isinstance(x, bool):
            return [str(x)]
        for i, j in enumerate(x):
            x[i] = str(j)
        return x

    # exclude from grid search
    flattened_config_dict = flatten(config, reducer='path')
    hyper_params = []

    for k, v in flattened_config_dict.items():
        if isinstance(v, list):
            if k in excluded_keys:
                flattened_config_dict[k] = ['+'.join(v)]
            elif len(v) > 1:
                hyper_params += [k]

        if isinstance(v, list) and isinstance(v[0], bool):
            flattened_config_dict[k] = bool_to_string(v)

        if not isinstance(v, list):
            if isinstance(v, bool):
                flattened_config_dict[k] = bool_to_string(v)
            else:
                flattened_config_dict[k] = [v]

    keys, values = zip(*flattened_config_dict.items())
    experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]

    for exp_id, exp in enumerate(experiments):
        for param in excluded_keys:
            exp[param] = exp[param].strip().split('+')

        for param_name, param_value in exp.items():
            if isinstance(param_value, list) and (param_value[0] in ['True', 'False']):
                exp[param_name] = [x == 'True' for x in param_value]
            if param_value in ['True', 'False']:
                exp[param_name] = param_value == 'True'

        experiments[exp_id] = unflatten(exp, splitter='path')

    return experiments, hyper_params
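# Sketch of the grid expansion performed above (keys are hypothetical):
#   get_grid_search_configs({'OPTIM': {'LR': [1e-4, 1e-5], 'WD': 1e-4}})
# yields two experiment dicts, {'OPTIM': {'LR': 1e-4, 'WD': 1e-4}} and
# {'OPTIM': {'LR': 1e-5, 'WD': 1e-4}}, plus the tuned key list ['OPTIM/LR'].
# List values under excluded_keys are joined with '+' first and split back
# afterwards, so they are carried along unexpanded.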
def get_from_dict(d, keys):
    """Access a nested dictionary entry by a sequence of keys."""
    return reduce(operator.getitem, keys, d)
def save_dict_to_yaml(obj, filename, mode='w'):
    with open(filename, mode) as f:
        yaml.dump(obj, f, default_flow_style=False)
def run_grid_search_experiments(
        args,
        script='train.py',
        change_wt_name=True
):
    with open(args.cfg) as f:
        cfg = yaml.safe_load(f)

    # Parse the config file into a list of configs, one per combination of the
    # tuned hyperparameters, and also return the names of those hyperparameters.
    different_configs, hyperparams = get_grid_search_configs(
        cfg,
        excluded_keys=['TRAINING/DATASETS', 'TRAINING/DATASET_MIX_PDF', 'VALIDATION/DATASETS'],
    )
    logger.info(f'Grid search hparams: \n {hyperparams}')

    # The config file may be missing some default values, so we need to add them
    different_configs = [update_hparams_from_dict(c) for c in different_configs]
    logger.info(f'======> Number of experiment configurations is {len(different_configs)}')

    config_to_run = CN(different_configs[args.cfg_id])

    if args.cluster:
        execute_task_on_cluster(
            script=script,
            exp_name=config_to_run.EXP_NAME,
            output_dir=config_to_run.OUTPUT_DIR,
            condor_dir=config_to_run.CONDOR_DIR,
            cfg_file=args.cfg,
            num_exp=len(different_configs),
            bid_amount=args.bid,
            num_workers=config_to_run.DATASET.NUM_WORKERS,
            memory=args.memory,
            exp_opts=args.opts,
            gpu_min_mem=args.gpu_min_mem,
            gpu_arch=args.gpu_arch,
        )
        exit()

    # ==== create logdir using hyperparam settings
    logtime = time.strftime('%d-%m-%Y_%H-%M-%S')
    logdir = f'{logtime}_{config_to_run.EXP_NAME}'
    wt_file = config_to_run.EXP_NAME + '_'

    for hp in hyperparams:
        v = get_from_dict(different_configs[args.cfg_id], hp.split('/'))
        logdir += f'_{hp.replace("/", ".").replace("_", "").lower()}-{v}'
        wt_file += f'{hp.replace("/", ".").replace("_", "").lower()}-{v}_'

    logdir = os.path.join(config_to_run.OUTPUT_DIR, logdir)
    os.makedirs(logdir, exist_ok=True)
    config_to_run.LOGDIR = logdir

    wt_file += 'best.pth'
    wt_path = os.path.join(os.path.dirname(config_to_run.TRAINING.BEST_MODEL_PATH), wt_file)
    if change_wt_name:
        config_to_run.TRAINING.BEST_MODEL_PATH = wt_path

    shutil.copy(src=args.cfg, dst=os.path.join(logdir, 'config.yaml'))

    # save config
    save_dict_to_yaml(
        unflatten(flatten(config_to_run)),
        os.path.join(config_to_run.LOGDIR, 'config_to_run.yaml')
    )

    return config_to_run
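# Minimal sketch of how this module is typically driven from a training entry
# point (illustrative only; `train(hparams)` is a hypothetical function, not
# part of this file):
#
#   from utils.config import parse_args, run_grid_search_experiments
#
#   if __name__ == '__main__':
#       args = parse_args()
#       hparams = run_grid_search_experiments(args, script='train.py')
#       train(hparams)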