ZHIJI_cv_web_ui / NTED /config.py
zejunyang
update
9667e74
import collections
import functools
import os
import re
import yaml
class AttrDict(dict):
    """Dictionary whose entries can also be accessed as attributes.

    The instance ``__dict__`` is aliased to the mapping itself, so
    ``cfg.key`` and ``cfg['key']`` address the same entry. At construction
    time nested dictionaries are wrapped in ``AttrDict``; a list or tuple
    whose first element is a dictionary is replaced by a list of
    ``AttrDict`` objects. Empty sequences are stored unchanged.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping from the same arguments ``dict`` accepts."""
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        # Only existing keys are reassigned below, so the mapping size does
        # not change while it is being iterated.
        for key, value in self.__dict__.items():
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # Check for emptiness first: value[0] on an empty sequence
                # would raise IndexError.
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]

    def yaml(self):
        """Convert the object to a plain (yaml-dumpable) dict and return it.

        Returns:
            dict: A new dictionary in which nested ``AttrDict`` values, and
            lists whose first element is an ``AttrDict``, are converted
            recursively. Every other value is passed through as is.
        """
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Return all entries as ``key: value`` lines, nesting indented."""
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                for line in value.__repr__().split('\n'):
                    ret_str.append(' ' + line)
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                ret_str.append('{}:'.format(key))
                for item in value:
                    # Each element is rendered like a nested AttrDict.
                    for line in item.__repr__().split('\n'):
                        ret_str.append(' ' + line)
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)
class Config(AttrDict):
    r"""Configuration class. This should include every human specifiable
    hyperparameter values for your training.

    Default values are set first and then overridden by the content of a
    YAML file.
    """

    def __init__(self, filename=None, verbose=False, is_train=True):
        r"""Populate defaults and merge the YAML file ``filename`` on top.

        Args:
            filename (str): Path of the YAML configuration file. It must
                exist.
            verbose (bool): If ``True``, print the resulting configuration.
            is_train (bool): Selects ``phase``: ``'train'`` when ``True``,
                ``'test'`` otherwise.

        Raises:
            AssertionError: If ``filename`` is ``None`` or does not exist.
            EnvironmentError: If the file cannot be opened or read; a hint
                is printed before the error is re-raised.
        """
        super(Config, self).__init__()
        # Set default parameters.
        # Logging.
        large_number = 1000000000
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.eval_epoch = large_number
        self.start_eval_epoch = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.image_to_tensorboard = False
        self.which_iter = None
        self.resume = True

        self.checkpoints_dir = 'NTED'
        self.name = 'nted_checkpoint.pt'
        self.phase = 'train' if is_train else 'test'

        # Networks.
        self.gen = AttrDict(type='generators.dummy')
        self.dis = AttrDict(type='discriminators.dummy')

        # Optimizers.
        self.gen_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))
        self.dis_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))

        # Data.
        self.data = AttrDict(name='dummy',
                             type='datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(name='dummy',
                                  type='datasets.images',
                                  num_workers=0,
                                  test=AttrDict(is_lmdb=False,
                                                roots='',
                                                batch_size=1))
        self.trainer = AttrDict(
            image_to_tensorboard=False,
            hparam_to_tensorboard=False)

        # Cudnn.
        self.cudnn = AttrDict(deterministic=False,
                              benchmark=True)

        # Others.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()

        # Update with given configurations.
        # The None check turns the default ``filename=None`` into the same
        # assertion failure as a missing file instead of an opaque TypeError
        # from os.path.exists.
        assert filename is not None and os.path.exists(filename), \
            'File {} not exist.'.format(filename)
        # Register a float resolver whose pattern also accepts exponent
        # notation without a decimal point (e.g. ``1e-8``).
        # NOTE(review): this mutates the shared yaml.SafeLoader class and
        # appends the resolver again on every Config() call. Kept as is
        # because other code may rely on the global effect - confirm before
        # switching to a private SafeLoader subclass.
        loader = yaml.SafeLoader
        loader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        try:
            with open(filename, 'r') as f:
                cfg_dict = yaml.load(f, Loader=loader)
        except EnvironmentError:
            print('Please check the file with name of "%s"' % filename)
            # Re-raise: continuing would only fail later with a NameError on
            # the unbound ``cfg_dict`` and hide the real cause.
            raise
        if cfg_dict is None:
            # An empty YAML document loads as None; keep the defaults.
            cfg_dict = {}

        recursive_update(self, cfg_dict)

        # Put common opts in both gen and dis.
        if 'common' in cfg_dict:
            self.common = AttrDict(**cfg_dict['common'])
            self.gen.common = self.common
            self.dis.common = self.common

        if verbose:
            print(' config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))
def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path on an object.

    Args:
        obj: Root object the path starts from.
        attr (str): Attribute name, optionally dotted (``'a.b.c'``). All but
            the last component are resolved with ``rgetattr``.
        val: Value assigned to the final attribute.

    Returns:
        None, the result of the underlying ``setattr`` call.
    """
    parent_path, _, leaf = attr.rpartition('.')
    if parent_path:
        target = rgetattr(obj, parent_path)
    else:
        # No dot in the path: the attribute lives directly on ``obj``.
        target = obj
    return setattr(target, leaf, val)
def rgetattr(obj, attr, *args):
    """Resolve a possibly dotted attribute path and return its value.

    Args:
        obj: Root object the path starts from.
        attr (str): Attribute name, optionally dotted (``'a.b.c'``).
        *args: Optional default, forwarded to every ``getattr`` call along
            the path.

    Returns:
        The value found after following each path component in turn.
    """
    current = obj
    for name in attr.split('.'):
        current = getattr(current, name, *args)
    return current
def recursive_update(d, u):
    """Recursively update AttrDict ``d`` in place with the mapping ``u``.

    Args:
        d (AttrDict): Target that receives the values.
        u (Mapping): Source of the overriding values.

    Returns:
        The updated ``d`` (the same object that was passed in).

    Mapping values are merged into the existing entry of ``d``. A list or
    tuple whose first element is a dict becomes a list of ``AttrDict``
    objects. Every other value, including an empty list or tuple, is stored
    unchanged.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            child = d.get(key)
            if not isinstance(child, AttrDict):
                # The existing entry may be missing, a scalar default (for
                # example None or '') or a plain dict; none of these can be
                # merged into through attribute-dict access.
                child = AttrDict(child) if isinstance(child, dict) \
                    else AttrDict()
            d.__dict__[key] = recursive_update(child, value)
        elif (isinstance(value, (list, tuple)) and value
              and isinstance(value[0], dict)):
            # The emptiness check above avoids IndexError on value[0].
            d.__dict__[key] = [AttrDict(item) for item in value]
        else:
            d.__dict__[key] = value
    return d