File size: 4,803 Bytes
2571f24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
"""Config utilities for yml file."""
import collections
import functools
import os
import re
import yaml
# from imaginaire.utils.distributed import master_only_print as print
class AttrDict(dict):
    """Dict whose keys can also be read and written as attributes.

    On construction, nested dicts are converted to AttrDict, and a list or
    tuple whose first element is a dict is converted to a list of AttrDict.
    Empty sequences and sequences of non-dict values are stored unchanged.
    """

    def __init__(self, *args, **kwargs):
        """Build the dict from the usual ``dict`` arguments and wrap children."""
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias the instance namespace to the dict itself so that
        # ``obj.key`` and ``obj['key']`` refer to the same storage.
        self.__dict__ = self
        # Iterate over a snapshot; only values of existing keys are replaced.
        for key, value in list(self.items()):
            if isinstance(value, dict):
                self[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # ``value and`` guards the [0] lookup against empty sequences,
                # which previously raised IndexError.
                if value and isinstance(value[0], dict):
                    self[key] = [AttrDict(item) for item in value]

    def yaml(self):
        """Return the content as plain nested dicts/lists.

        Returns:
            dict: AttrDict values, and lists whose first element is an
            AttrDict, are converted recursively; everything else is passed
            through as-is.
        """
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif (isinstance(value, list) and value
                    and isinstance(value[0], AttrDict)):
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                # Scalars, empty lists, tuples and lists of plain values.
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Return a multi-line ``key: value`` listing of all entries.

        Nested AttrDict values (and lists whose first element is an AttrDict)
        are rendered under a ``key:`` header with each child line prefixed by
        the indent string.
        """
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                for line in value.__repr__().split('\n'):
                    ret_str.append(' ' + line)
            elif (isinstance(value, list) and value
                    and isinstance(value[0], AttrDict)):
                ret_str.append('{}:'.format(key))
                for element in value:
                    # Each element is rendered like a nested AttrDict.
                    for line in element.__repr__().split('\n'):
                        ret_str.append(' ' + line)
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)
class Config(AttrDict):
    r"""Configuration class. This should include every human specifiable
    hyperparameter values for your training.

    The content of a YAML file is merged into the instance with
    ``recursive_update``, so top-level keys become attributes.
    """

    def __init__(self, filename=None, verbose=False):
        r"""Load the configuration stored in ``filename``.

        Args:
            filename (str): Path to the YAML config file. It must exist.
            verbose (bool): If ``True``, print the loaded configuration.

        Raises:
            ValueError: If ``filename`` is ``None`` or does not exist.
            EnvironmentError: If the file exists but cannot be read.
        """
        super(Config, self).__init__()
        # Check None explicitly: os.path.exists(None) raises TypeError, which
        # hid the intended "path not existed" error for the default argument.
        if filename is None or not os.path.exists(filename):
            raise ValueError('Provided config path not existed: %s' % filename)

        # Register the extra float resolver on a private subclass. Calling
        # add_implicit_resolver on yaml.SafeLoader itself would modify that
        # shared class for the whole process and append one more resolver
        # every time a Config is created.
        class _FloatSafeLoader(yaml.SafeLoader):
            """SafeLoader with an additional implicit float resolver."""

        # The second alternative accepts an exponent without a decimal point
        # (e.g. ``1e-4``), so such scalars are tagged as floats.
        _FloatSafeLoader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
            [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        try:
            with open(filename, 'r') as f:
                # The loader derives from SafeLoader, so no arbitrary Python
                # objects are constructed from the file.
                cfg_dict = yaml.load(f, Loader=_FloatSafeLoader)
        except EnvironmentError:
            print('Please check the file with name of "%s"' % filename)
            # Re-raise: continuing would fail on the unbound ``cfg_dict`` and
            # mask the real I/O error.
            raise
        # An empty YAML file yields None, which recursive_update ignores.
        recursive_update(self, cfg_dict)

        if verbose:
            print(' imaginaire config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))
def rsetattr(obj, attr, val):
    """Set the attribute named by a dotted path, e.g. ``'a.b.c'``, to ``val``.

    Everything before the last dot is resolved with ``rgetattr``; the final
    component is then set on the resulting object. A path without dots sets
    the attribute directly on ``obj``.
    """
    parent_path, _, leaf = attr.rpartition('.')
    target = obj
    if parent_path:
        target = rgetattr(obj, parent_path)
    return setattr(target, leaf, val)
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path, e.g. ``'a.b.c'``, starting at ``obj``.

    Extra positional arguments are forwarded to every ``getattr`` call, so an
    optional default applies at each step of the path.
    """
    current = obj
    for name in attr.split('.'):
        current = getattr(current, name, *args)
    return current
def recursive_update(d, u):
    """Recursively update AttrDict ``d`` in place with the mapping ``u``.

    Args:
        d: Target whose entries are written through ``d.__dict__``.
        u: Mapping of updates, or ``None`` to leave ``d`` unchanged.

    Returns:
        The updated ``d``.

    Mapping values are merged into the existing child. Lists/tuples whose
    first element is a dict become lists of AttrDict. Every other value
    replaces the existing entry.
    """
    if u is None:
        return d
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            child = d.get(key)
            if not isinstance(child, AttrDict):
                # The existing entry may be missing, None, a scalar or a plain
                # dict; recursing into those raised AttributeError. Wrap a
                # plain mapping, and start from an empty AttrDict otherwise.
                if isinstance(child, collections.abc.Mapping):
                    child = AttrDict(child)
                else:
                    child = AttrDict()
            d.__dict__[key] = recursive_update(child, value)
        elif (isinstance(value, (list, tuple)) and len(value) > 0
                and isinstance(value[0], dict)):
            d.__dict__[key] = [AttrDict(item) for item in value]
        else:
            d.__dict__[key] = value
    return d
|