Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
"""Config utilities for yml file.""" | |
import collections | |
import functools | |
import os | |
import re | |
import yaml | |
from imaginaire.utils.distributed import master_only_print as print | |
# Module-level switches. Neither name is read anywhere in the visible part of
# this file. NOTE(review): presumably consumed by modules that import this
# one -- confirm before changing or removing.
DEBUG = False
USE_JIT = False
class AttrDict(dict):
    """Dict as attribute trick.

    The instance ``__dict__`` is bound to the dictionary itself, so
    ``d.key`` and ``d['key']`` read and write the same storage. On
    construction, nested dicts are wrapped as ``AttrDict`` and any list or
    tuple whose first element is a dict is replaced by a list of
    ``AttrDict`` objects.
    """

    def __init__(self, *args, **kwargs):
        """Build the dict and wrap nested dicts.

        Args:
            *args: Positional arguments forwarded to ``dict``.
            **kwargs: Keyword arguments forwarded to ``dict``.
        """
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias attribute storage to the mapping itself.
        self.__dict__ = self
        # Only existing keys are reassigned below, so the dict never changes
        # size while it is being iterated.
        for key, value in self.__dict__.items():
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # Check for emptiness before peeking at the first element;
                # an empty sequence is stored unchanged.
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]
                else:
                    self.__dict__[key] = value

    def yaml(self):
        """Convert object to yaml dict and return.

        Returns:
            dict: A plain ``dict`` in which nested ``AttrDict`` values, and
            lists whose first element is an ``AttrDict``, have been
            converted recursively. Every other value is passed through.
        """
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif (isinstance(value, list) and value
                    and isinstance(value[0], AttrDict)):
                # The first element decides how the whole list is treated.
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables.

        Returns:
            str: One ``key: value`` line per entry. Nested ``AttrDict``
            values (and lists of them) are rendered on the lines following
            a ``key:`` header, each prefixed with the indent string.
        """
        # NOTE(review): the indent literal is reproduced exactly as it
        # appears in the reviewed text -- confirm the intended width.
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                for line in value.__repr__().split('\n'):
                    ret_str.append(' ' + line)
            elif (isinstance(value, list) and value
                    and isinstance(value[0], AttrDict)):
                ret_str.append('{}:'.format(key))
                for item in value:
                    # Treat as AttrDict above.
                    for line in item.__repr__().split('\n'):
                        ret_str.append(' ' + line)
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)
class Config(AttrDict):
    r"""Configuration class. This should include every human specifiable
    hyperparameter values for your training."""

    def __init__(self, filename=None, verbose=False):
        r"""Set the default options, then overlay the values from a YAML file.

        Args:
            filename (str): Path of the YAML configuration file. The path
                must exist; the default ``None`` is rejected by the
                existence check.
            verbose (bool): If ``True``, print the resulting configuration.

        Raises:
            AssertionError: If ``filename`` does not exist.
            EnvironmentError: If the file exists but cannot be opened or
                read. A hint naming the file is printed before re-raising.
        """
        super(Config, self).__init__()
        self.source_filename = filename
        # Set default parameters.
        # Logging.
        large_number = 1000000000
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.metrics_iter = None
        self.metrics_epoch = None
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.image_display_iter = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.speed_benchmark = False

        # Trainer.
        self.trainer = AttrDict(
            model_average_config=AttrDict(enabled=False,
                                          beta=0.9999,
                                          start_iteration=1000,
                                          num_batch_norm_estimation_iterations=30,
                                          remove_sn=True),
            # model_average=False,
            # model_average_beta=0.9999,
            # model_average_start_iteration=1000,
            # model_average_batch_norm_estimation_iteration=30,
            # model_average_remove_sn=True,
            image_to_tensorboard=False,
            hparam_to_tensorboard=False,
            distributed_data_parallel='pytorch',
            distributed_data_parallel_params=AttrDict(
                find_unused_parameters=False),
            delay_allreduce=True,
            gan_relativistic=False,
            gen_step=1,
            dis_step=1,
            gan_decay_k=1.,
            gan_min_k=1.,
            gan_separate_topk=False,
            aug_policy='',
            channels_last=False,
            strict_resume=True,
            amp_gp=False,
            amp_config=AttrDict(init_scale=65536.0,
                                growth_factor=2.0,
                                backoff_factor=0.5,
                                growth_interval=2000,
                                enabled=False))

        # Networks.
        self.gen = AttrDict(type='imaginaire.generators.dummy')
        self.dis = AttrDict(type='imaginaire.discriminators.dummy')

        # Optimizers.
        self.gen_opt = AttrDict(type='adam',
                                fused_opt=False,
                                lr=0.0001,
                                adam_beta1=0.0,
                                adam_beta2=0.999,
                                eps=1e-8,
                                lr_policy=AttrDict(iteration_mode=False,
                                                   type='step',
                                                   step_size=large_number,
                                                   gamma=1))
        self.dis_opt = AttrDict(type='adam',
                                fused_opt=False,
                                lr=0.0001,
                                adam_beta1=0.0,
                                adam_beta2=0.999,
                                eps=1e-8,
                                lr_policy=AttrDict(iteration_mode=False,
                                                   type='step',
                                                   step_size=large_number,
                                                   gamma=1))

        # Data.
        self.data = AttrDict(name='dummy',
                             type='imaginaire.datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(name='dummy',
                                  type='imaginaire.datasets.images',
                                  num_workers=0,
                                  test=AttrDict(is_lmdb=False,
                                                roots='',
                                                batch_size=1))

        # Cudnn.
        self.cudnn = AttrDict(deterministic=False,
                              benchmark=True)

        # Others.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()

        # Update with given configurations.
        assert os.path.exists(filename), 'File {} not exist.'.format(filename)
        # Register an extra implicit float resolver so that scalars matching
        # the pattern below (e.g. exponent forms without a decimal point)
        # are tagged as floats.
        # NOTE(review): ``loader`` is the shared ``yaml.SafeLoader`` class
        # itself, so this registration is process-wide and is repeated for
        # every Config that gets constructed.
        loader = yaml.SafeLoader
        loader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        try:
            with open(filename, 'r') as f:
                cfg_dict = yaml.load(f, Loader=loader)
        except EnvironmentError:
            print('Please check the file with name of "%s"' % filename)
            # Re-raise: continuing would only fail later on the unbound
            # ``cfg_dict`` with a NameError that hides the real cause.
            raise
        # An empty YAML document loads as None; treat it as "no overrides".
        if cfg_dict is None:
            cfg_dict = {}

        recursive_update(self, cfg_dict)

        # Put common opts in both gen and dis.
        if 'common' in cfg_dict:
            self.common = AttrDict(**cfg_dict['common'])
            self.gen.common = self.common
            self.dis.common = self.common

        if verbose:
            print(' imaginaire config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))
def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path on an object.

    Everything before the last dot of ``attr`` is resolved with
    ``rgetattr`` to find the owner; the final component is then assigned
    on that owner. A path without dots is assigned on ``obj`` directly.

    Args:
        obj: Root object.
        attr (str): Attribute path such as ``'a.b.c'``.
        val: Value to assign.

    Returns:
        The result of ``setattr`` (``None``).
    """
    owner_path, _, leaf = attr.rpartition('.')
    owner = obj
    if owner_path:
        owner = rgetattr(obj, owner_path)
    return setattr(owner, leaf, val)
def rgetattr(obj, attr, *args):
    """Read a possibly dotted attribute path from an object.

    Each dot-separated component of ``attr`` is looked up in turn with
    ``getattr``, starting from ``obj``.

    Args:
        obj: Root object.
        attr (str): Attribute path such as ``'a.b.c'``.
        *args: Extra positional arguments handed to every ``getattr`` call,
            i.e. an optional default value.

    Returns:
        The value found at the end of the path.
    """
    current = obj
    for name in attr.split('.'):
        # The same optional default is applied at every level of the path.
        current = getattr(current, name, *args)
    return current
def recursive_update(d, u):
    """Recursively update AttrDict d with AttrDict u

    Values are written through ``d.__dict__``. Mapping values are merged
    recursively into the existing entry of ``d`` (or into a fresh
    ``AttrDict`` when the key is absent). A non-empty list or tuple whose
    first element is a dict is replaced by a list of ``AttrDict`` objects.
    Every other value, including an empty list or tuple, is stored as is.

    Args:
        d: Object being updated in place.
        u: Mapping that supplies the new values.

    Returns:
        The updated ``d``.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
        elif (isinstance(value, (list, tuple)) and value
                and isinstance(value[0], dict)):
            # The emptiness check comes first so that an empty sequence
            # does not raise IndexError on value[0].
            d.__dict__[key] = [AttrDict(item) for item in value]
        else:
            d.__dict__[key] = value
    return d