import os
import json
import logging
from collections import OrderedDict
from datetime import datetime

def mkdirs(paths):
    """Create one directory (str) or several (iterable of str)."""
    if isinstance(paths, str):
        os.makedirs(paths, exist_ok=True)
    else:
        for path in paths:
            os.makedirs(path, exist_ok=True)

def get_timestamp():
    """Timestamp string like '240131_235959', used to name experiment folders."""
    return datetime.now().strftime('%y%m%d_%H%M%S')

def parse(args):
    """Parse a JSON config file (with '//' line comments) into an options dict,
    applying command-line overrides from the argparse namespace `args`."""
    phase = args.phase
    opt_path = args.config
    gpu_ids = args.gpu_ids
    enable_wandb = args.enable_wandb
    # Strip '//' comments so the file can be fed to the standard json parser.
    # Note: this also truncates any value containing '//', e.g. URLs.
    json_str = ''
    with open(opt_path, 'r') as f:
        for line in f:
            line = line.split('//')[0] + '\n'
            json_str += line
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)
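    # Illustrative config fragment in the accepted style (hypothetical values,
    # not taken from the repository):
    #     {
    #         "name": "sr_experiment",   // experiment name
    #         "gpu_ids": [0],
    #         "path": { "log": "logs", "checkpoint": "checkpoint" }
    #     }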
    # set up the experiment directory tree under experiments/<name>_<timestamp>
    if args.debug:
        opt['name'] = 'debug_{}'.format(opt['name'])
    experiments_root = os.path.join(
        'experiments', '{}_{}'.format(opt['name'], get_timestamp()))
    opt['path']['experiments_root'] = experiments_root
    for key, path in opt['path'].items():
        # Resume paths refer to existing runs; leave them untouched.
        if 'resume' not in key and 'experiments' not in key:
            opt['path'][key] = os.path.join(experiments_root, path)
            mkdirs(opt['path'][key])
    # record the running phase ('train' / 'val')
    opt['phase'] = phase
    # export CUDA_VISIBLE_DEVICES so only the selected GPUs are visible
    if gpu_ids is not None:
        opt['gpu_ids'] = [int(gpu_id) for gpu_id in gpu_ids.split(',')]
        gpu_list = gpu_ids
    else:
        gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
    # Count GPUs rather than characters: `len(gpu_list) > 1` would misclassify
    # a single two-digit id such as '10' as a distributed run.
    opt['distributed'] = len(opt['gpu_ids']) > 1
    # debug: shrink frequencies and dataset lengths so a full loop runs quickly
    if 'debug' in opt['name']:
        opt['train']['val_freq'] = 2
        opt['train']['print_freq'] = 2
        opt['train']['save_checkpoint_freq'] = 3
        opt['datasets']['train']['batch_size'] = 2
        opt['model']['beta_schedule']['train']['n_timestep'] = 10
        opt['model']['beta_schedule']['val']['n_timestep'] = 10
        opt['datasets']['train']['data_len'] = 6
        opt['datasets']['val']['data_len'] = 3
    # validation in train phase
    if phase == 'train':
        opt['datasets']['val']['data_len'] = 3
    # W&B logging: these flags are optional on the argparse namespace, so probe
    # with hasattr() instead of bare `except:` clauses that swallow real errors.
    for flag in ('log_wandb_ckpt', 'log_eval', 'log_infer'):
        if hasattr(args, flag):
            opt[flag] = getattr(args, flag)
    opt['enable_wandb'] = enable_wandb
    return opt

class NoneDict(dict):
    """dict subclass that returns None instead of raising KeyError for missing keys."""
    def __missing__(self, key):
        return None


# convert to NoneDict, which returns None for missing keys
def dict_to_nonedict(opt):
    if isinstance(opt, dict):
        new_opt = dict()
        for key, sub_opt in opt.items():
            new_opt[key] = dict_to_nonedict(sub_opt)
        return NoneDict(**new_opt)
    elif isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    else:
        return opt
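
# Worked example (illustrative):
#     opt = dict_to_nonedict({'train': {'lr': 1e-4}})
#     opt['train']['lr']       -> 0.0001
#     opt['train']['missing']  -> None (no KeyError)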

def dict2str(opt, indent_l=1):
    '''dict to string for logger'''
    msg = ''
    for k, v in opt.items():
        if isinstance(v, dict):
            msg += ' ' * (indent_l * 2) + k + ':[\n'
            msg += dict2str(v, indent_l + 1)
            msg += ' ' * (indent_l * 2) + ']\n'
        else:
            msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
    return msg
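
# Sample output (illustrative), for dict2str({'train': {'lr': 0.0001}}):
#   train:[
#     lr: 0.0001
#   ]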

def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
    '''set up a named logger writing to <root>/<phase>.log (and optionally stdout)'''
    logger = logging.getLogger(logger_name)
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
        datefmt='%y-%m-%d %H:%M:%S')
    log_file = os.path.join(root, '{}.log'.format(phase))
    fh = logging.FileHandler(log_file, mode='w')
    fh.setFormatter(formatter)
    logger.setLevel(level)
    logger.addHandler(fh)
    if screen:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)
    # Note: calling this twice with the same logger_name attaches duplicate handlers.
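

# Minimal usage sketch (illustrative only): the real CLI wiring lives in the
# training script, not in this module, and the flag names below are assumptions
# mirroring the attributes parse() reads. Assumes the config defines path.log.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True)
    parser.add_argument('-p', '--phase', type=str, default='train')
    parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
    parser.add_argument('-debug', action='store_true')
    parser.add_argument('-enable_wandb', action='store_true')
    args = parser.parse_args()

    # parse() builds the options dict; dict_to_nonedict() makes missing keys
    # return None instead of raising.
    opt = dict_to_nonedict(parse(args))
    setup_logger('base', opt['path']['log'], opt['phase'], screen=True)
    logging.getLogger('base').info('Options:\n' + dict2str(opt))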