File size: 3,596 Bytes
4673b21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import sys
import json
import random
from ast import literal_eval
import numpy as np
import torch
# -----------------------------------------------------------------------------
def set_seed(seed):
    """ seed the python, numpy and torch (incl. cuda) random number generators with the same value """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # apply the seed to each generator in turn
    for apply_seed in seeders:
        apply_seed(seed)
def setup_logging(config):
    """ record the command line and the config of this run inside config.system.work_dir """
    out_dir = config.system.work_dir
    # make sure the output directory is there before writing into it
    os.makedirs(out_dir, exist_ok=True)

    # the command line, as a single space-separated string
    args_path = os.path.join(out_dir, 'args.txt')
    with open(args_path, 'w') as args_file:
        args_file.write(' '.join(sys.argv))

    # the full config, as pretty-printed json
    config_path = os.path.join(out_dir, 'config.json')
    with open(config_path, 'w') as config_file:
        config_file.write(json.dumps(config.to_dict(), indent=4))
class CfgNode:
    """ a lightweight configuration class inspired by yacs """
    # TODO: convert to subclass from a dict like in yacs?
    # TODO: implement freezing to prevent shooting of own foot
    # TODO: additional existence/override checks when reading/writing params?

    def __init__(self, **kwargs):
        """ every keyword argument becomes an attribute of this node """
        self.__dict__.update(kwargs)

    def __str__(self):
        """ pretty-print the config, one `key: value` per line, children indented by 4 spaces """
        return self._str_helper(0)

    def _str_helper(self, indent):
        """ need to have a helper to support nested indentation for pretty printing """
        pad = ' ' * (indent * 4)
        parts = []
        for k, v in self.__dict__.items():
            if isinstance(v, CfgNode):
                parts.append("%s%s:\n" % (pad, k))
                # the child indents its own lines, so it must not be padded again
                # here (doing so over-indents its first line at depth >= 2)
                parts.append(v._str_helper(indent + 1))
            else:
                parts.append("%s%s: %s\n" % (pad, k, v))
        return "".join(parts)

    def to_dict(self):
        """ return a dict representation of the config (nested CfgNodes become nested dicts) """
        return {k: v.to_dict() if isinstance(v, CfgNode) else v for k, v in self.__dict__.items()}

    def merge_from_dict(self, d):
        """ set (or overwrite) attributes of this node from the entries of the dict d """
        self.__dict__.update(d)

    def merge_from_args(self, args):
        """
        update the configuration from a list of strings that is expected
        to come from the command line, i.e. sys.argv[1:].

        The arguments are expected to be in the form of `--arg=value`, and
        the arg can use . to denote nested sub-attributes. Example:

        --model.n_layer=10 --trainer.batch_size=32

        Only the first `=` separates the arg from the value, so the value may
        itself contain `=`. Values that parse as a python literal are converted,
        anything else is kept as the raw string.

        Raises AssertionError if an arg is malformed or names a leaf attribute
        that does not exist, and AttributeError if an intermediate attribute
        along a dotted path does not exist.
        """
        for arg in args:

            # split on the first '=' only, so that values may contain '=' themselves
            keyval = arg.split('=', 1)
            assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg
            key, val = keyval # unpack

            # first translate val into a python object
            # - if val represents a thing (like an 3, 3.14, [1,2,3], False, None, etc.) it will get created
            # - if val is simply a string, literal_eval will throw: ValueError for bare
            #   words like gpt2, SyntaxError for things like ./out or /tmp/run, and
            #   TypeError for a few malformed literals. In all of these keep the string.
            try:
                val = literal_eval(val)
            except (ValueError, SyntaxError, TypeError):
                pass

            # find the appropriate object to insert the attribute into
            assert key[:2] == '--'
            key = key[2:] # strip the '--'
            keys = key.split('.')
            obj = self
            for k in keys[:-1]:
                obj = getattr(obj, k)
            leaf_key = keys[-1]

            # ensure that this attribute exists
            assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config"

            # overwrite the attribute
            print("command line overwriting config attribute %s with %s" % (key, val))
            setattr(obj, leaf_key, val)
|