# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
"""Config utilities for yml file."""
import collections.abc
import functools
import os
import re
import yaml
from imaginaire.utils.distributed import master_only_print as print
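# Module-level switches; other parts of imaginaire may import these to toggle
# extra debugging output and TorchScript (JIT) code paths.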
DEBUG = False
USE_JIT = False
class AttrDict(dict):
"""Dict as attribute trick."""
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
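        # Alias the instance __dict__ to the mapping itself so that every key
        # is also readable and writable as an attribute (d['k'] <-> d.k).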
self.__dict__ = self
for key, value in self.__dict__.items():
if isinstance(value, dict):
self.__dict__[key] = AttrDict(value)
elif isinstance(value, (list, tuple)):
                if value and isinstance(value[0], dict):
self.__dict__[key] = [AttrDict(item) for item in value]
else:
self.__dict__[key] = value
def yaml(self):
"""Convert object to yaml dict and return."""
yaml_dict = {}
for key, value in self.__dict__.items():
if isinstance(value, AttrDict):
yaml_dict[key] = value.yaml()
elif isinstance(value, list):
                if value and isinstance(value[0], AttrDict):
new_l = []
for item in value:
new_l.append(item.yaml())
yaml_dict[key] = new_l
else:
yaml_dict[key] = value
else:
yaml_dict[key] = value
return yaml_dict
def __repr__(self):
"""Print all variables."""
ret_str = []
for key, value in self.__dict__.items():
if isinstance(value, AttrDict):
ret_str.append('{}:'.format(key))
child_ret_str = value.__repr__().split('\n')
for item in child_ret_str:
ret_str.append(' ' + item)
elif isinstance(value, list):
                if value and isinstance(value[0], AttrDict):
ret_str.append('{}:'.format(key))
for item in value:
                    # Treat each element as an AttrDict, as above.
child_ret_str = item.__repr__().split('\n')
for item in child_ret_str:
ret_str.append(' ' + item)
else:
ret_str.append('{}: {}'.format(key, value))
else:
ret_str.append('{}: {}'.format(key, value))
return '\n'.join(ret_str)
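# Illustrative AttrDict usage (the keys below are only examples):
#   d = AttrDict({'gen': {'type': 'imaginaire.generators.dummy'}})
#   d.gen.type   # -> 'imaginaire.generators.dummy'
#   d.yaml()     # -> plain nested dict, ready for yaml.dump().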
class Config(AttrDict):
r"""Configuration class. This should include every human specifiable
hyperparameter values for your training."""
def __init__(self, filename=None, verbose=False):
super(Config, self).__init__()
self.source_filename = filename
# Set default parameters.
# Logging.
large_number = 1000000000
self.snapshot_save_iter = large_number
self.snapshot_save_epoch = large_number
self.metrics_iter = None
self.metrics_epoch = None
self.snapshot_save_start_iter = 0
self.snapshot_save_start_epoch = 0
self.image_save_iter = large_number
self.image_display_iter = large_number
self.max_epoch = large_number
self.max_iter = large_number
self.logging_iter = 100
self.speed_benchmark = False
# Trainer.
self.trainer = AttrDict(
model_average_config=AttrDict(enabled=False,
beta=0.9999,
start_iteration=1000,
num_batch_norm_estimation_iterations=30,
remove_sn=True),
# model_average=False,
# model_average_beta=0.9999,
# model_average_start_iteration=1000,
# model_average_batch_norm_estimation_iteration=30,
# model_average_remove_sn=True,
image_to_tensorboard=False,
hparam_to_tensorboard=False,
distributed_data_parallel='pytorch',
distributed_data_parallel_params=AttrDict(
find_unused_parameters=False),
delay_allreduce=True,
gan_relativistic=False,
gen_step=1,
dis_step=1,
gan_decay_k=1.,
gan_min_k=1.,
gan_separate_topk=False,
aug_policy='',
channels_last=False,
strict_resume=True,
amp_gp=False,
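            # init_scale, growth_factor, backoff_factor and growth_interval
            # below match the defaults of torch.cuda.amp.GradScaler; AMP
            # stays off unless ``enabled`` is set to True.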
amp_config=AttrDict(init_scale=65536.0,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=False))
# Networks.
self.gen = AttrDict(type='imaginaire.generators.dummy')
self.dis = AttrDict(type='imaginaire.discriminators.dummy')
# Optimizers.
self.gen_opt = AttrDict(type='adam',
fused_opt=False,
lr=0.0001,
adam_beta1=0.0,
adam_beta2=0.999,
eps=1e-8,
lr_policy=AttrDict(iteration_mode=False,
type='step',
step_size=large_number,
gamma=1))
self.dis_opt = AttrDict(type='adam',
fused_opt=False,
lr=0.0001,
adam_beta1=0.0,
adam_beta2=0.999,
eps=1e-8,
lr_policy=AttrDict(iteration_mode=False,
type='step',
step_size=large_number,
gamma=1))
# Data.
self.data = AttrDict(name='dummy',
type='imaginaire.datasets.images',
num_workers=0)
self.test_data = AttrDict(name='dummy',
type='imaginaire.datasets.images',
num_workers=0,
test=AttrDict(is_lmdb=False,
roots='',
batch_size=1))
# Cudnn.
self.cudnn = AttrDict(deterministic=False,
benchmark=True)
# Others.
self.pretrained_weight = ''
self.inference_args = AttrDict()
# Update with given configurations.
        assert os.path.exists(filename), \
            'File {} does not exist.'.format(filename)
loader = yaml.SafeLoader
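        # PyYAML's stock SafeLoader follows YAML 1.1 and parses scientific
        # notation without a dot (e.g. ``1e-4``) as a string, so register a
        # stricter implicit float resolver before loading the config.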
loader.add_implicit_resolver(
u'tag:yaml.org,2002:float',
re.compile(u'''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X),
list(u'-+0123456789.'))
try:
with open(filename, 'r') as f:
cfg_dict = yaml.load(f, Loader=loader)
        except EnvironmentError:
            print('Please check the file named "{}".'.format(filename))
            raise
recursive_update(self, cfg_dict)
# Put common opts in both gen and dis.
if 'common' in cfg_dict:
self.common = AttrDict(**cfg_dict['common'])
self.gen.common = self.common
self.dis.common = self.common
if verbose:
print(' imaginaire config '.center(80, '-'))
print(self.__repr__())
print(''.center(80, '-'))
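# Illustrative Config usage (the path below is hypothetical):
#   cfg = Config('configs/example.yaml', verbose=True)
#   cfg.gen_opt.lr   # defaults above, overridden by values in the YAML file.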
def rsetattr(obj, attr, val):
"""Recursively find object and set value"""
pre, _, post = attr.rpartition('.')
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
def rgetattr(obj, attr, *args):
"""Recursively find object and return value"""
def _getattr(obj, attr):
r"""Get attribute."""
return getattr(obj, attr, *args)
return functools.reduce(_getattr, [obj] + attr.split('.'))
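# Illustrative usage (dotted paths are resolved one attribute at a time):
#   rgetattr(cfg, 'gen_opt.lr')          # read a nested value.
#   rsetattr(cfg, 'gen_opt.lr', 0.0002)  # override a nested value.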
def recursive_update(d, u):
"""Recursively update AttrDict d with AttrDict u"""
for key, value in u.items():
if isinstance(value, collections.abc.Mapping):
d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
elif isinstance(value, (list, tuple)):
        if value and isinstance(value[0], dict):
d.__dict__[key] = [AttrDict(item) for item in value]
else:
d.__dict__[key] = value
else:
d.__dict__[key] = value
return d