"""YAML configuration loading with ``_BASE_`` inheritance and CLI overrides."""
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from collections.abc import Mapping

import yaml

__all__ = ['Config']


class ArgsParser(ArgumentParser):
    """Argument parser that turns ``-o key=value`` options into a nested dict.

    The parser registers ``-o/--opt`` and ``--local_rank``. ``parse_args``
    additionally requires a ``config`` attribute on the parsed namespace; this
    class does not register that option itself, so the code constructing the
    parser has to add it (e.g. ``parser.add_argument('-c', '--config')``).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument(
            '-o', '--opt', nargs='*', help='set configuration options')
        self.add_argument('--local_rank')

    def parse_args(self, argv=None):
        """Parse ``argv`` and replace ``args.opt`` with its nested-dict form.

        Args:
            argv: Argument list; ``None`` lets argparse read ``sys.argv``.

        Returns:
            The parsed namespace, with ``opt`` converted by ``_parse_opt``.

        Raises:
            AssertionError: If the namespace has no ``config`` value.
        """
        args = super(ArgsParser, self).parse_args(argv)
        # ``getattr`` keeps the helpful message below even when the
        # ``--config`` option was never registered on this parser; a plain
        # ``args.config`` would fail with an opaque AttributeError instead.
        assert getattr(args, 'config', None) is not None, \
            'Please specify --config=configure_file_path.'
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``['a.b=1', 'c=x']`` into ``{'a': {'b': 1}, 'c': 'x'}``.

        Args:
            opts: Iterable of ``key=value`` strings (or ``None``/empty). Dots
                in the key denote nesting; the value is parsed as YAML.

        Returns:
            A nested dict of overrides (empty when ``opts`` is falsy).

        Raises:
            ValueError: If an option does not contain ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    "expected option in 'key=value' form, got '{}'".format(s))
            k, v = s.split('=', 1)
            # NOTE(security): yaml.Loader can construct arbitrary Python
            # objects. Acceptable for trusted command lines only; consider
            # yaml.safe_load if options may come from untrusted input.
            value = yaml.load(v, Loader=yaml.Loader)
            keys = k.split('.')
            cur = config
            for key in keys[:-1]:
                child = cur.get(key)
                # Reuse an existing nested dict so that several options
                # sharing a prefix ("a.b.c=1 a.b.d=2") do not wipe each other.
                if not isinstance(child, dict):
                    child = cur[key] = {}
                cur = child
            cur[keys[-1]] = value
        return config


class AttrDict(dict):
    """Single level attribute dict, NOT recursive."""

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__()
        super(AttrDict, self).update(kwargs)

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods
        # keep priority over same-named keys.
        if key in self:
            return self[key]
        raise AttributeError("object has no attribute '{}'".format(key))


def _merge_dict(config, merge_dct):
    """Recursively merge ``merge_dct`` into ``config`` in place.

    Inspired by :meth:``dict.update()``; instead of updating only top-level
    keys, the merge recurses into nested dicts of arbitrary depth. String keys
    containing dots (``'a.b.c'``) are treated as paths into nested dicts.

    Args:
        config: dict onto which the merge is executed (mutated).
        merge_dct: mapping whose entries are merged into ``config``.

    Returns:
        ``config`` itself, after the merge.
    """
    for key, value in merge_dct.items():
        if isinstance(key, str) and '.' in key:
            head, rest = key.split('.', 1)
            child = config.get(head)
            # Create (or replace a non-dict with) the intermediate level so
            # the remainder of the dotted path is not silently dropped.
            if not isinstance(child, dict):
                child = config[head] = {}
            _merge_dict(child, {rest: value})
        elif isinstance(config.get(key), dict) and isinstance(value, Mapping):
            _merge_dict(config[key], value)
        else:
            config[key] = value
    return config


def print_dict(cfg, print_func=print, delimiter=0):
    """Recursively print a dict, indenting nested levels by four spaces.

    Args:
        cfg: dict to print; keys are emitted in sorted order.
        print_func: callable receiving each formatted line.
        delimiter: number of spaces used to indent the current level.
    """
    for k, v in sorted(cfg.items()):
        if isinstance(v, dict):
            print_func('{}{} : '.format(delimiter * ' ', str(k)))
            print_dict(v, print_func, delimiter + 4)
        elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
            # A list is expanded only when its first element is a dict.
            print_func('{}{} : '.format(delimiter * ' ', str(k)))
            for value in v:
                print_dict(value, print_func, delimiter + 4)
        else:
            print_func('{}{} : {}'.format(delimiter * ' ', k, v))


class Config(object):
    """Configuration loaded from a YAML file, with optional base files.

    A file may name other YAML files under ``BASE_KEY`` (default ``_BASE_``);
    those are loaded first and the file's own entries are merged over them.
    The merged result is stored in ``self.cfg``.
    """

    def __init__(self, config_path, BASE_KEY='_BASE_'):
        self.BASE_KEY = BASE_KEY
        self.cfg = self._load_config_with_base(config_path)

    def _load_config_with_base(self, file_path):
        """Load config from file, resolving base files recursively.

        Args:
            file_path (str): Path of the config file to be loaded.

        Returns:
            The merged config mapping, with a ``filename`` entry holding the
            file's base name without extension.

        Raises:
            AssertionError: If the file extension is not ``.yml``/``.yaml``.
        """
        _, ext = os.path.splitext(file_path)
        assert ext in ['.yml', '.yaml'], 'only support yaml files for now'

        with open(file_path) as f:
            # NOTE(security): yaml.Loader can construct arbitrary Python
            # objects; only load config files from trusted sources.
            file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            # An empty YAML document parses to None; treat it as no settings.
            file_cfg = {}

        # NOTE: cfgs outside have higher priority than cfgs in _BASE_
        if self.BASE_KEY in file_cfg:
            all_base_cfg = AttrDict()
            base_ymls = file_cfg[self.BASE_KEY]
            if base_ymls is None:
                base_ymls = []
            elif isinstance(base_ymls, str):
                # A single path must not be iterated character by character.
                base_ymls = [base_ymls]
            else:
                base_ymls = list(base_ymls)
            for base_yml in base_ymls:
                if base_yml.startswith('~'):
                    base_yml = os.path.expanduser(base_yml)
                if not os.path.isabs(base_yml):
                    # Relative bases are resolved against the including file.
                    base_yml = os.path.join(
                        os.path.dirname(file_path), base_yml)
                base_cfg = self._load_config_with_base(base_yml)
                all_base_cfg = _merge_dict(all_base_cfg, base_cfg)
            del file_cfg[self.BASE_KEY]
            file_cfg = _merge_dict(all_base_cfg, file_cfg)
        file_cfg['filename'] = os.path.splitext(
            os.path.basename(file_path))[0]
        return file_cfg

    def merge_dict(self, args):
        """Merge the mapping ``args`` into the loaded config."""
        self.cfg = _merge_dict(self.cfg, args)

    def print_cfg(self, print_func=print):
        """Print the loaded config between header and footer lines.

        Args:
            print_func: callable receiving each output line.
        """
        print_func('----------- Config -----------')
        print_dict(self.cfg, print_func)
        print_func('---------------------------------------------')

    def save(self, p, cfg=None):
        """Write ``cfg`` (default: the loaded config) to path ``p`` as YAML.

        Key order is preserved (``sort_keys=False``).
        """
        if cfg is None:
            cfg = self.cfg
        with open(p, 'w') as f:
            yaml.dump(dict(cfg), f, default_flow_style=False, sort_keys=False)