|
import os |
|
from typing import Tuple |
|
from functools import reduce |
|
|
|
from argparse import Namespace |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
|
|
|
|
|
|
|
|
|
|
def register_resolver():
    """Register the arithmetic resolvers ``add``, ``multiply`` and ``sub`` with OmegaConf.

    They can then be used in config interpolations, e.g. ``${add:1,2}``.

    - ``add``: sum of any number of arguments (0 when called with none).
    - ``multiply``: product of any number of arguments (1 when called with none).
    - ``sub``: ``n1 - n2`` for exactly two arguments.

    Safe to call repeatedly: ``replace=True`` overwrites an earlier registration
    instead of raising because the resolver name is already taken.
    """
    OmegaConf.register_new_resolver(
        "add", lambda *numbers: sum(numbers), replace=True
    )
    OmegaConf.register_new_resolver(
        # Initial value 1 makes the empty product well-defined, mirroring sum() -> 0.
        "multiply", lambda *numbers: reduce(lambda x, y: x * y, numbers, 1), replace=True
    )
    OmegaConf.register_new_resolver(
        "sub", lambda n1, n2: n1 - n2, replace=True
    )
|
|
|
|
|
def _to_dotlist_item(key, value):
    """Format one argparse entry as a ``key=value`` item for ``OmegaConf.from_cli``."""
    # OmegaConf parses the right-hand side as YAML. The text "None" is not YAML
    # null, so it would become the string "None"; spell it "null" to keep None.
    if value is None:
        return f"{key}=null"
    # An empty right-hand side parses as null; quote it to keep the empty string.
    if isinstance(value, str) and not value:
        return f"{key}=''"
    return f"{key}={value}"


def _merge_args_and_config(
    cmd_args: Namespace,
    yaml_config: DictConfig,
    read_only: bool = False
) -> Tuple[DictConfig, DictConfig, DictConfig]:
    """Merge argparse arguments on top of a YAML-loaded config.

    Every attribute of ``cmd_args`` becomes a ``key=value`` dotlist item, so its
    value is re-parsed as YAML by OmegaConf (e.g. ``"True"`` -> ``True``,
    ``"3"`` -> ``3``). ``None`` and ``""`` are passed through unchanged.
    Command-line values take precedence over the YAML values.

    Args:
        cmd_args: parsed argparse namespace.
        yaml_config: config loaded from the YAML file.
        read_only: if True, the merged config is made read-only.

    Returns:
        ``(merged_config, cmd_args_config, yaml_config)``.
    """
    cmd_args_list = [
        _to_dotlist_item(key, value) for key, value in vars(cmd_args).items()
    ]
    cmd_args_conf = OmegaConf.from_cli(cmd_args_list)

    # Later configs win in OmegaConf.merge, so CLI values override the YAML file.
    args_ = OmegaConf.merge(yaml_config, cmd_args_conf)

    if read_only:
        OmegaConf.set_readonly(args_, True)

    return args_, cmd_args_conf, yaml_config
|
|
|
|
|
def merge_configs(args, method_cfg_path):
    """Merge command-line args (argparse) with a YAML config file (OmegaConf).

    The YAML file is looked up at ``./config/<method_cfg_path>``.

    Args:
        args: parsed argparse namespace.
        method_cfg_path: config file path relative to ``./config``.

    Returns:
        ``(merged_config, cmd_args_config, yaml_config)`` as produced by
        ``_merge_args_and_config`` (not read-only).

    Raises:
        FileNotFoundError: the config file does not exist. The original
            exception is re-raised, so ``errno``/``filename`` and the traceback
            are preserved.
    """
    yaml_config_path = os.path.join("./", "config", method_cfg_path)
    try:
        yaml_config = OmegaConf.load(yaml_config_path)
    except FileNotFoundError as e:
        print(f"error: {e}")
        print(f"input file path: `{method_cfg_path}`")
        print(f"config path: `{yaml_config_path}` not found.")
        # Bare raise keeps the original exception object instead of wrapping it.
        raise
    return _merge_args_and_config(args, yaml_config, read_only=False)
|
|
|
|
|
def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
    """Apply whitespace-separated ``key=value`` overrides to an OmegaConf config.

    Args:
        source_args: config to update. It is not modified; a new merged config
            is returned.
        update_nodes: override string such as ``"a.b=1 c=foo"``. ``None`` or a
            whitespace-only string means "no overrides".
        strict: if True, assert that no overridden key currently holds a
            missing (``???``) value.
        remove_update_nodes: if True, set the ``update`` key of the result to
            ``""`` after merging.

    Returns:
        ``source_args`` itself when there is nothing to apply, otherwise the
        config produced by ``OmegaConf.merge``.

    Raises:
        AssertionError: in strict mode, when an overridden key is missing.
    """
    if update_nodes is None:
        return source_args

    update_args_list = str(update_nodes).split()
    if not update_args_list:
        return source_args

    if strict:
        for item in update_args_list:
            item_key_ = item.split('=')[0]
            # NOTE(review): OmegaConf.is_missing looks the key up directly on
            # source_args; dotted keys such as "a.b" are presumably not
            # resolved and so pass this check unvalidated — confirm.
            assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing."

    # Keys absent from source_args are added by the merge itself, so the input
    # config never needs to be touched beforehand.
    update_conf = OmegaConf.from_dotlist(update_args_list)
    merged_args = OmegaConf.merge(source_args, update_conf)

    if remove_update_nodes:
        OmegaConf.update(merged_args, 'update', '')
    return merged_args
|
|
|
|
|
def update_if_exist(source_args, update_nodes):
    """Merge whitespace-separated ``key=value`` overrides into a config.

    Args:
        source_args: config the overrides are merged into.
        update_nodes: override string such as ``"a.b=1 c=foo"``. ``None`` or a
            whitespace-only string means "no overrides".

    Returns:
        ``source_args`` itself when there is nothing to apply, otherwise the
        new config produced by ``OmegaConf.merge``.

    NOTE(review): despite the name, every override is applied whether or not
    its key already exists in ``source_args``; keys are never inspected.
    Confirm whether filtering by existence was intended.
    """
    tokens = [] if update_nodes is None else str(update_nodes).split()
    if not tokens:
        return source_args
    overrides = OmegaConf.from_dotlist(tokens)
    return OmegaConf.merge(source_args, overrides)
|
|
|
|
|
def merge_and_update_config(args):
    """Build the final run configuration from argparse args, a YAML file and overrides.

    Steps:
      1. Register the arithmetic resolvers (``register_resolver``).
      2. If ``args.config`` ends with ``.yaml``, merge that file with ``args``
         (``merge_configs``); otherwise use ``args`` itself as the config.
      3. Apply the ``key=value`` overrides in ``args.update`` (``update_configs``).
      4. Attach the override-applied YAML config and command-line config as
         ``final_args.yaml_config`` and ``final_args.cmd_args``.
      5. Replace a negative ``seed`` with a random integer in ``[0, 65535]``.

    Args:
        args: parsed argparse namespace; ``config``, ``update`` and ``seed`` are read.

    Returns:
        The merged config when a YAML file was used, otherwise ``args`` itself
        (updated in place).
    """
    import copy
    import random

    register_resolver()

    if args.config is not None and str(args.config).endswith('.yaml'):
        merged_args, cmd_args, yaml_config = merge_configs(args, args.config)
    else:
        # No YAML file: ``args`` is the config. Record the command line as a
        # separate copy. Aliasing it to ``args`` would make
        # ``final_args.cmd_args`` point back at ``final_args`` itself, a cycle
        # that makes e.g. ``repr(args)`` recurse without end.
        merged_args, cmd_args, yaml_config = args, copy.copy(args), None

    update_nodes = args.update
    # NOTE(review): update_configs/update_if_exist go through OmegaConf; with a
    # plain Namespace (no YAML) and a non-empty ``update`` this presumably
    # fails — confirm.
    final_args = update_configs(merged_args, update_nodes)

    # Keep per-source records with the overrides applied. There is no YAML
    # config to update when none was loaded.
    yaml_config_update = None if yaml_config is None else update_if_exist(yaml_config, update_nodes)
    cmd_args_update = update_if_exist(cmd_args, update_nodes)
    cmd_args_update.update = ""

    final_args.yaml_config = yaml_config_update
    final_args.cmd_args = cmd_args_update

    # A negative seed requests a randomly drawn one.
    if final_args.seed < 0:
        final_args.seed = random.randint(0, 65535)

    return final_args
|
|