"""Helpers for reading/writing YAML configs and logging run arguments."""
import sys

import torch
import yaml


def load_yaml_config(path):
    """Read a YAML file and return the parsed document.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.full_load`` for the file's contents
        (a ``dict`` when the document is a top-level mapping).
    """
    # NOTE(review): ``yaml.full_load`` can construct some Python-specific
    # objects from tagged YAML. If config files may come from untrusted
    # sources, consider ``yaml.safe_load`` instead (kept as-is so existing
    # configs load exactly as before).
    with open(path, encoding='utf-8') as f:
        return yaml.full_load(f)


def save_config_to_yaml(config, path):
    """Serialize ``config`` with ``yaml.dump`` and write it to ``path``.

    Args:
        config: Object to serialize.
        path: Destination file path; must end with ``'.yaml'``. An existing
            file at this path is overwritten.

    Raises:
        ValueError: If ``path`` does not end with ``'.yaml'``.
    """
    # Explicit check instead of ``assert`` so the guard is not stripped
    # when Python runs with ``-O``.
    if not path.endswith('.yaml'):
        raise ValueError(
            "config path must end with '.yaml', got {!r}".format(path))
    # The ``with`` statement closes the file; no explicit close() needed.
    with open(path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f)


def write_args(args, path):
    """Append environment info, the command line and all public args to a file.

    The appended text contains, in order:
    the torch version, the cuDNN version reported by
    ``torch.backends.cudnn.version()``, ``str(sys.argv)``, and one
    ``' name: value'`` line per public attribute of ``args``, sorted by name.

    Args:
        args: Object whose public attributes are logged (presumably an
            ``argparse.Namespace`` -- confirm against callers).
        path: File to append to; created if it does not exist.
    """
    # Every attribute name from dir() that does not start with '_' is
    # logged. For objects other than a plain namespace this also picks up
    # public methods/class attributes.
    args_dict = {name: getattr(args, name)
                 for name in dir(args) if not name.startswith('_')}

    # Build the full text first so that a failure while gathering info
    # does not leave a partially written entry in the file.
    lines = [
        '==> torch version: {}\n'.format(torch.__version__),
        '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()),
        '==> Cmd:\n',
        str(sys.argv),
        '\n==> args:\n',
    ]
    # Keys are unique, so sorting the items orders them by attribute name.
    lines.extend(' %s: %s\n' % (str(k), str(v))
                 for k, v in sorted(args_dict.items()))

    with open(path, 'a', encoding='utf-8') as args_file:
        args_file.writelines(lines)