"""Experiment configuration utilities: command-line parsing, hyperparameter
grid-search expansion, and run/log-directory setup."""
import itertools | |
import operator | |
import os | |
import shutil | |
import time | |
from functools import reduce | |
from typing import List, Union | |
import configargparse | |
import yaml | |
from flatten_dict import flatten, unflatten | |
from loguru import logger | |
from yacs.config import CfgNode as CN | |
from utils.cluster import execute_task_on_cluster | |
from utils.default_hparams import hparams | |
def parse_args():
    """Parse command-line arguments shared by local and cluster runs.

    Returns:
        argparse.Namespace: parsed options (cfg, opts, cfg_id, cluster,
        bid, memory, gpu_min_mem, gpu_arch, num_cpus).
    """
    def add_common_cmdline_args(parser):
        # for cluster runs
        parser.add_argument('--cfg', required=True, type=str, help='cfg file path')
        parser.add_argument('--opts', default=[], nargs='*', help='additional options to update config')
        parser.add_argument('--cfg_id', type=int, default=0, help='cfg id to run when multiple experiments are spawned')
        parser.add_argument('--cluster', default=False, action='store_true', help='creates submission files for cluster')
        parser.add_argument('--bid', type=int, default=10, help='amount of bid for cluster')
        parser.add_argument('--memory', type=int, default=64000, help='memory amount for cluster')
        parser.add_argument('--gpu_min_mem', type=int, default=12000, help='minimum amount of GPU memory')
        # FIX: help text was copy-pasted from --opts; describe the actual option.
        parser.add_argument('--gpu_arch', default=['tesla', 'quadro', 'rtx'],
                            nargs='*', help='acceptable GPU architectures for cluster jobs')
        parser.add_argument('--num_cpus', type=int, default=8, help='num cpus for cluster')
        return parser

    # For Blender main parser
    arg_formatter = configargparse.ArgumentDefaultsHelpFormatter
    cfg_parser = configargparse.YAMLConfigFileParser
    description = 'PyTorch implementation of DECO'
    parser = configargparse.ArgumentParser(formatter_class=arg_formatter,
                                           config_file_parser_class=cfg_parser,
                                           description=description,
                                           prog='deco')
    parser = add_common_cmdline_args(parser)
    args = parser.parse_args()
    print(args, end='\n\n')
    return args
def get_hparams_defaults():
    """Return a fresh copy of the project-wide default hyperparameters.

    A clone is handed out so callers can mutate the result freely
    without altering the shared defaults.
    """
    defaults = hparams.clone()
    return defaults
def update_hparams(hparams_file):
    """Load the defaults and overlay values from a YAML hparams file.

    :param hparams_file: path to a YAML file understood by yacs.
    :return: a clone of the merged configuration node.
    """
    cfg = get_hparams_defaults()
    cfg.merge_from_file(hparams_file)
    return cfg.clone()
def update_hparams_from_dict(cfg_dict):
    """Load the defaults and overlay values from a plain config dict.

    :param cfg_dict: dictionary of overrides (stringified and re-parsed
        by yacs' ``load_cfg``).
    :return: a clone of the merged configuration node.
    """
    base = get_hparams_defaults()
    overrides = base.load_cfg(str(cfg_dict))
    base.merge_from_other_cfg(overrides)
    return base.clone()
def get_grid_search_configs(config, excluded_keys=None):
    """
    Expand a config dict with list-valued entries into grid-search combinations.

    :param config: (possibly nested) dictionary with the configurations; list
        values are treated as hyperparameter choices to sweep over.
    :param excluded_keys: flattened ('/'-joined) keys whose list value is one
        multi-valued setting (e.g. a dataset mix), not a sweep axis.
    :return: tuple ``(experiments, hyper_params)`` — the list of nested config
        dicts (one per combination) and the names of the swept keys.
    """
    # FIX: avoid the mutable default argument ([]).
    if excluded_keys is None:
        excluded_keys = []

    def bool_to_string(x: Union[List[bool], bool]) -> Union[List[str], str]:
        """
        boolean to string conversion
        :param x: list or bool to be converted
        :return: list of 'True'/'False' strings (NOTE: mutates list input in place)
        """
        if isinstance(x, bool):
            return [str(x)]
        for i, j in enumerate(x):
            x[i] = str(j)
        return x

    # Flatten the nested dict using '/'-joined keys so sweeps can be product-ed.
    flattened_config_dict = flatten(config, reducer='path')
    hyper_params = []

    for k, v in flattened_config_dict.items():
        if isinstance(v, list):
            if k in excluded_keys:
                # Collapse into one '+'-joined token so it is not swept;
                # it is split back apart after the product below.
                flattened_config_dict[k] = ['+'.join(v)]
            elif len(v) > 1:
                hyper_params += [k]
            # FIX: guard against an empty list before peeking at v[0]
            # (previously raised IndexError).
            if isinstance(v, list) and v and isinstance(v[0], bool):
                # Booleans travel through itertools.product as strings.
                flattened_config_dict[k] = bool_to_string(v)
        if not isinstance(v, list):
            if isinstance(v, bool):
                flattened_config_dict[k] = bool_to_string(v)
            else:
                # Wrap scalars so product() sees uniform list-valued entries.
                flattened_config_dict[k] = [v]

    keys, values = zip(*flattened_config_dict.items())
    experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]

    for exp_id, exp in enumerate(experiments):
        # Restore excluded keys back to their original list form.
        for param in excluded_keys:
            exp[param] = exp[param].strip().split('+')
        # Convert stringified booleans back to real bools.
        for param_name, param_value in exp.items():
            if isinstance(param_value, list) and (param_value[0] in ['True', 'False']):
                exp[param_name] = [True if x == 'True' else False for x in param_value]
            if param_value in ['True', 'False']:
                if param_value == 'True':
                    exp[param_name] = True
                else:
                    exp[param_name] = False
        experiments[exp_id] = unflatten(exp, splitter='path')

    return experiments, hyper_params
def get_from_dict(dict, keys):
    """Fetch a value from a nested dict by following a sequence of keys.

    :param dict: nested mapping to walk (name kept for backward
        compatibility with existing callers; it shadows the builtin).
    :param keys: iterable of keys applied outermost-first.
    :return: the value reached after following every key.
    """
    node = dict
    for key in keys:
        node = node[key]
    return node
def save_dict_to_yaml(obj, filename, mode='w'):
    """Serialize ``obj`` to a YAML file in block (non-flow) style.

    :param obj: object to dump (typically a plain dict).
    :param filename: destination path.
    :param mode: file-open mode; 'w' overwrites, 'a' appends.
    """
    with open(filename, mode) as out_file:
        yaml.dump(obj, out_file, default_flow_style=False)
def run_grid_search_experiments(
        args,
        script='train.py',
        change_wt_name=True
):
    """Expand the YAML config at ``args.cfg`` into grid-search experiments
    and prepare the one selected by ``args.cfg_id``.

    When ``args.cluster`` is set, submits every configuration as a cluster
    job and exits the process. Otherwise, creates a timestamped log
    directory, optionally rewrites the best-model checkpoint path to encode
    the tuned hyperparameter values, saves the raw and resolved configs,
    and returns the ready-to-run configuration.

    :param args: parsed namespace (uses cfg, cfg_id, cluster, bid, memory,
        opts, gpu_min_mem, gpu_arch).
    :param script: training script name forwarded to the cluster submission.
    :param change_wt_name: if True, point TRAINING.BEST_MODEL_PATH at a
        filename that encodes the swept hyperparameter values.
    :return: a yacs CfgNode for the selected experiment.
    """
    cfg = yaml.safe_load(open(args.cfg))  # NOTE(review): file handle is never closed — consider a `with` block
    # parse config file to split into a list of configs with tuning hyperparameters separated
    # Also return the names of tuned hyperparameters hyperparameters
    different_configs, hyperparams = get_grid_search_configs(
        cfg,
        excluded_keys=['TRAINING/DATASETS', 'TRAINING/DATASET_MIX_PDF', 'VALIDATION/DATASETS'],
    )
    logger.info(f'Grid search hparams: \n {hyperparams}')
    # The config file may be missing some default values, so we need to add them
    different_configs = [update_hparams_from_dict(c) for c in different_configs]
    logger.info(f'======> Number of experiment configurations is {len(different_configs)}')
    config_to_run = CN(different_configs[args.cfg_id])
    if args.cluster:
        # Submit ALL configurations (num_exp) as cluster jobs, then quit this process.
        execute_task_on_cluster(
            script=script,
            exp_name=config_to_run.EXP_NAME,
            output_dir=config_to_run.OUTPUT_DIR,
            condor_dir=config_to_run.CONDOR_DIR,
            cfg_file=args.cfg,
            num_exp=len(different_configs),
            bid_amount=args.bid,
            num_workers=config_to_run.DATASET.NUM_WORKERS,
            memory=args.memory,
            exp_opts=args.opts,
            gpu_min_mem=args.gpu_min_mem,
            gpu_arch=args.gpu_arch,
        )
        exit()
    # ==== create logdir using hyperparam settings
    logtime = time.strftime('%d-%m-%Y_%H-%M-%S')
    logdir = f'{logtime}_{config_to_run.EXP_NAME}'
    wt_file = config_to_run.EXP_NAME + '_'
    # Encode each tuned hyperparameter's value into both the logdir name
    # and the checkpoint filename so runs are distinguishable on disk.
    for hp in hyperparams:
        v = get_from_dict(different_configs[args.cfg_id], hp.split('/'))
        logdir += f'_{hp.replace("/", ".").replace("_", "").lower()}-{v}'
        wt_file += f'{hp.replace("/", ".").replace("_", "").lower()}-{v}_'
    logdir = os.path.join(config_to_run.OUTPUT_DIR, logdir)
    os.makedirs(logdir, exist_ok=True)
    config_to_run.LOGDIR = logdir
    wt_file += 'best.pth'
    wt_path = os.path.join(os.path.dirname(config_to_run.TRAINING.BEST_MODEL_PATH), wt_file)
    if change_wt_name: config_to_run.TRAINING.BEST_MODEL_PATH = wt_path
    # Keep a copy of the raw YAML next to the resolved config for reproducibility.
    shutil.copy(src=args.cfg, dst=os.path.join(logdir, 'config.yaml'))
    # save config
    save_dict_to_yaml(
        unflatten(flatten(config_to_run)),
        os.path.join(config_to_run.LOGDIR, 'config_to_run.yaml')
    )
    return config_to_run