Spaces:
Sleeping
Sleeping
| # ----------------------------------------------------------------------------- | |
| # Functions for parsing args | |
| # ----------------------------------------------------------------------------- | |
| import copy | |
| import os | |
| from ast import literal_eval | |
| import yaml | |
class CfgNode(dict):
    """Dict-backed configuration node with attribute-style key access.

    Entries can be read and written either as items (``cfg["lr"]``) or as
    attributes (``cfg.lr``). Plain ``dict`` values handed to the constructor
    are converted recursively into ``CfgNode`` instances, so nested access
    such as ``cfg.train.lr`` works as well.
    """

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        """Build a node from ``init_dict`` without modifying it.

        Args:
            init_dict: Mapping of initial entries; ``None`` means empty.
            key_list: Keys leading from the root to this node. It is only
                extended with the child key and passed on to child nodes.
            new_allowed: Accepted for interface compatibility; not used by
                this implementation.
        """
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list

        # Collect entries in a fresh dict so the caller's mapping keeps its
        # original plain-dict children instead of being rewritten in place.
        converted = {}
        for k, v in init_dict.items():
            # Exact type check on purpose: values that are already CfgNode
            # (or another dict subclass) are stored unchanged.
            if type(v) is dict:
                v = CfgNode(v, key_list=key_list + [k])
            converted[k] = v
        super(CfgNode, self).__init__(converted)

    def __getattr__(self, name):
        """Return the entry stored under ``name``.

        Only called when normal attribute lookup fails. Raises
        ``AttributeError`` for unknown keys, which keeps ``hasattr`` and
        ``copy.deepcopy`` working.
        """
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        """Store ``value`` as the dict entry ``name``."""
        self[name] = value

    def __str__(self):
        """Render the node as sorted ``key: value`` lines.

        Child nodes are printed on the lines below their key, indented by
        two spaces per nesting level.
        """

        def _indent(text, num_spaces):
            # Indent every line except the first by `num_spaces` spaces.
            head, newline, tail = text.partition("\n")
            if not newline:
                return text
            pad = num_spaces * " "
            return head + "\n" + "\n".join(
                pad + line for line in tail.split("\n"))

        lines = []
        for k, v in sorted(self.items()):
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            lines.append(_indent(attr_str, 2))
        return "\n".join(lines)

    def __repr__(self):
        """Return ``ClassName(<dict repr>)``."""
        return "{}({})".format(self.__class__.__name__,
                               super(CfgNode, self).__repr__())
def load_cfg_from_cfg_file(file):
    """Load a sectioned YAML config file into one flat ``CfgNode``.

    The top level of the file is iterated as a mapping of section names to
    option mappings. Section names are dropped and all options are merged
    into a single namespace; if two sections define the same option, the
    section that appears later wins.

    Args:
        file: Path to an existing file whose name ends with ``.yaml``.

    Returns:
        A ``CfgNode`` holding every option from every section. An empty
        YAML document, or a section with no body, contributes no options.

    Raises:
        AssertionError: If ``file`` is not an existing ``.yaml`` file.
    """
    assert os.path.isfile(file) and file.endswith('.yaml'), \
        '{} is not a yaml file'.format(file)

    with open(file, 'r') as f:
        cfg_from_file = yaml.safe_load(f)

    # safe_load yields None for an empty document; treat it as "no sections"
    # instead of failing when iterating over None.
    if cfg_from_file is None:
        cfg_from_file = {}

    cfg = {}
    for key in cfg_from_file:
        section = cfg_from_file[key]
        # A section header with nothing under it also parses to None.
        if section is None:
            continue
        for k, v in section.items():
            cfg[k] = v

    return CfgNode(cfg)
def merge_cfg_from_list(cfg, cfg_list):
    """Return a copy of ``cfg`` with overrides from a flat key/value list.

    ``cfg_list`` alternates keys and raw values, e.g.
    ``["TRAIN.lr", "0.1", "gpus", "[0, 1]"]``. Only the part of each key
    after the last ``.`` is used for the lookup. Each raw value is decoded
    with ``_decode_cfg_value`` and type-checked against the current value in
    ``cfg`` before being set on the copy. ``cfg`` itself is left untouched.

    Args:
        cfg: Existing configuration to take a deep copy of.
        cfg_list: Sequence of even length: key, value, key, value, ...

    Returns:
        The updated deep copy of ``cfg``.

    Raises:
        AssertionError: If ``cfg_list`` has odd length or names a key that
            is not present in ``cfg``.
        ValueError: If a decoded value has an incompatible type.
    """
    new_cfg = copy.deepcopy(cfg)
    assert len(cfg_list) % 2 == 0

    # Walk the list two items at a time: (key, raw value).
    items = iter(cfg_list)
    for full_key, raw in zip(items, items):
        subkey = full_key.rpartition('.')[2]
        assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
        decoded = _decode_cfg_value(raw)
        coerced = _check_and_coerce_cfg_value_type(
            decoded, cfg[subkey], subkey, full_key)
        setattr(new_cfg, subkey, coerced)

    return new_cfg
| def _decode_cfg_value(v): | |
| """Decodes a raw config value (e.g., from a yaml config files or command | |
| line argument) into a Python object. | |
| """ | |
| # All remaining processing is only applied to strings | |
| if not isinstance(v, str): | |
| return v | |
| # Try to interpret `v` as a: | |
| # string, number, tuple, list, dict, boolean, or None | |
| try: | |
| v = literal_eval(v) | |
| # The following two excepts allow v to pass through when it represents a | |
| # string. | |
| # | |
| # Longer explanation: | |
| # The type of v is always a string (before calling literal_eval), but | |
| # sometimes it *represents* a string and other times a data structure, like | |
| # a list. In the case that v represents a string, what we got back from the | |
| # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is | |
| # ok with '"foo"', but will raise a ValueError if given 'foo'. In other | |
| # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval | |
| # will raise a SyntaxError. | |
| except ValueError: | |
| pass | |
| except SyntaxError: | |
| pass | |
| return v | |
| def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): | |
| """Checks that `replacement`, which is intended to replace `original` is of | |
| the right type. The type is correct if it matches exactly or is one of a few | |
| cases in which the type can be easily coerced. | |
| """ | |
| original_type = type(original) | |
| replacement_type = type(replacement) | |
| # The types must match (with some exceptions) | |
| if replacement_type == original_type: | |
| return replacement | |
| # Cast replacement from from_type to to_type if the replacement and original | |
| # types match from_type and to_type | |
| def conditional_cast(from_type, to_type): | |
| if replacement_type == from_type and original_type == to_type: | |
| return True, to_type(replacement) | |
| else: | |
| return False, None | |
| # Conditionally casts | |
| # list <-> tuple | |
| casts = [(tuple, list), (list, tuple)] | |
| # For py2: allow converting from str (bytes) to a unicode string | |
| try: | |
| casts.append((str, unicode)) # noqa: F821 | |
| except Exception: | |
| pass | |
| for (from_type, to_type) in casts: | |
| converted, converted_value = conditional_cast(from_type, to_type) | |
| if converted: | |
| return converted_value | |
| raise ValueError( | |
| "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " | |
| "key: {}".format(original_type, replacement_type, original, | |
| replacement, full_key)) | |