|
|
|
|
|
|
|
|
|
import os |
|
import omegaconf |
|
from omegaconf import OmegaConf |
|
|
|
|
|
def load_config(args=None, config_file=None, overwrite_fairseq=False):
    """Load a task config, adjust it for the current run, and persist it.

    TODO (huxu): move fairseq overwrite to another function.

    Steps performed:
      1. Load ``config_file`` (or ``args.taskconfig`` when ``args`` is given)
         via ``recursive_config``.
      2. If ``dataset.subsampling`` is set, integer-divide
         ``fairseq.dataset.batch_size`` by it.
      3. Unless ``dataset.split == "test"``: require
         ``fairseq.checkpoint.save_dir``, create that directory, replace a set
         ``fairseq.common.tensorboard_logdir`` with a numbered directory under
         ``save_dir`` (via ``suffix_rundir``), and write the config to
         ``<save_dir>/config.yaml``.
      4. If ``overwrite_fairseq`` is true and ``args`` is given, copy every
         ``fairseq.<group>.<field>`` value onto ``args`` as attribute ``field``.

    NOTE(review): the ``... is None`` checks below assume that missing keys
    read as None (OmegaConf 2.0 attribute-access behaviour); newer OmegaConf
    releases raise on missing keys instead -- confirm the pinned version.

    Args:
        args: optional argparse-style namespace; when given, its
            ``taskconfig`` attribute overrides ``config_file``.
        config_file: path of the YAML config to load.
        overwrite_fairseq: whether to copy the ``fairseq`` section onto ``args``.

    Returns:
        The loaded (and possibly modified) OmegaConf config.

    Raises:
        ValueError: if the split is not "test" and no
            ``fairseq.checkpoint.save_dir`` is configured.
    """
    if args is not None:
        config_file = args.taskconfig
    config = recursive_config(config_file)

    subsampling = config.dataset.subsampling
    if subsampling is not None:
        batch_size = config.fairseq.dataset.batch_size // subsampling
        print(
            "adjusting batch_size to {} due to subsampling {}.".format(
                batch_size, subsampling
            )
        )
        config.fairseq.dataset.batch_size = batch_size

    # ``None == "test"`` is False, so an unset split is treated as non-test.
    is_test = config.dataset.split == "test"
    if not is_test:
        checkpoint = config.fairseq.checkpoint
        if checkpoint is None or checkpoint.save_dir is None:
            raise ValueError("fairseq save_dir or save_path must be specified.")

        save_dir = checkpoint.save_dir
        # Created once here; suffix_rundir below lists this directory.
        os.makedirs(save_dir, exist_ok=True)
        if config.fairseq.common.tensorboard_logdir is not None:
            tb_run_dir = suffix_rundir(
                save_dir, config.fairseq.common.tensorboard_logdir
            )
            config.fairseq.common.tensorboard_logdir = tb_run_dir
            print(
                "update tensorboard_logdir as", config.fairseq.common.tensorboard_logdir
            )
        OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml"))

    if overwrite_fairseq and config.fairseq is not None and args is not None:
        for group in config.fairseq:
            for field in config.fairseq[group]:
                value = config.fairseq[group][field]
                print("overwrite args." + field, "as", value)
                setattr(args, field, value)
    return config
|
|
|
|
|
def recursive_config(config_path):
    """Load a YAML config, following ``includes`` chains to any depth.

    If the loaded config has a non-null top-level ``includes`` entry, that
    entry is removed, the file it names is loaded the same way (recursively),
    and the current config is merged on top of it with ``OmegaConf.merge``,
    so keys in the including file override those of the included one.

    The ``includes`` value is passed to ``OmegaConf.load`` as given; it is not
    resolved against the including file's directory. There is no cycle
    detection.

    Args:
        config_path: path of the YAML file to load.

    Returns:
        The merged OmegaConf config.
    """
    config = OmegaConf.load(config_path)
    # ``.get`` yields None for an absent key without relying on attribute
    # access returning None for missing keys.
    includes = config.get("includes")
    if includes is not None:
        config.pop("includes")
        base_config = recursive_config(includes)
        config = OmegaConf.merge(base_config, config)
    return config
|
|
|
|
|
def suffix_rundir(save_dir, run_dir):
    """Return the next numbered run-directory path under ``save_dir``.

    Entries of ``save_dir`` are scanned for:
      * ``run_dir`` itself, counted as id 0;
      * ``run_dir + "_<n>"`` with ``<n>`` a decimal number, optionally followed
        by a further ``"_..."`` suffix, counted as id ``n``.
    Any other entry (e.g. ``"runs_backup"`` or ``"runs2"`` for
    ``run_dir="runs"``) is ignored. The result is
    ``os.path.join(save_dir, run_dir + "_" + str(max_id + 1))``, i.e.
    ``run_dir + "_0"`` when nothing matches. The directory is not created.

    Args:
        save_dir: existing directory to scan.
        run_dir: base name of the run directory; may contain underscores.

    Returns:
        The joined path string.
    """
    max_id = -1
    for entry in os.listdir(save_dir):
        if not entry.startswith(run_dir):
            continue
        rest = entry[len(run_dir):]
        if not rest:
            cur_id = 0
        elif rest.startswith("_"):
            # Parse only the text after ``run_dir`` so underscores inside
            # ``run_dir`` itself (e.g. "tb_logs") cannot shift the id field.
            head = rest[1:].split("_", 1)[0]
            if not head.isdecimal():
                continue  # not a numbered run directory
            cur_id = int(head)
        else:
            continue  # a different name that merely shares the prefix
        max_id = max(max_id, cur_id)
    return os.path.join(save_dir, run_dir + "_" + str(max_id + 1))
|
|
|
|
|
def overwrite_dir(config, replace, basedir):
    """Rewrite a leading path prefix throughout a nested config, in place.

    Every string value that starts with ``basedir`` has that leading
    ``basedir`` replaced by ``replace``; later occurrences of ``basedir``
    inside the string are left alone. Nested mappings are walked
    recursively. Other values (numbers, lists, ...) are not modified.

    Args:
        config: an OmegaConf ``DictConfig`` or any other mutable mapping.
        replace: the new prefix.
        basedir: the prefix to look for.
    """
    # DictConfig implements MutableMapping, so this covers OmegaConf nodes
    # as well as plain nested dicts.
    from collections.abc import MutableMapping

    for key in config:
        value = config[key]
        if isinstance(value, str):
            if value.startswith(basedir):
                # Swap only the prefix; str.replace would also rewrite any
                # later occurrence of ``basedir`` inside the path.
                config[key] = replace + value[len(basedir):]
        elif isinstance(value, MutableMapping):
            overwrite_dir(value, replace, basedir)
|
|