Spaces:
Running
Running
File size: 5,215 Bytes
5b83793 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import os.path as osp
import logging
import yaml
import sys
sys.path.append('../../..')
from SRFlow.code.utils.util import OrderedYaml
# Module-level YAML Loader/Dumper pair built by the project helper OrderedYaml();
# Loader is what parse() passes to yaml.load. Presumably these keep mapping key
# order (going by the helper's name) — TODO confirm in SRFlow.code.utils.util.
Loader, Dumper = OrderedYaml()
def parse(opt_path, is_train=True, root='/kaggle/working/'):
    """Load a YAML option file and fill in the settings derived from it.

    Args:
        opt_path: Path to the YAML option file.
        is_train: True for a training run, which sets up the
            ``experiments/<name>`` layout and the debug overrides. False for
            testing, which uses the ``results/<name>`` layout.
        root: Directory stored in ``opt['path']['root']`` and used as the
            base of the experiment/result directories. Defaults to the Kaggle
            working directory that this function previously hard-coded.

    Returns:
        The mapping produced by ``yaml.load`` with ``Loader``, updated in
        place with:

        * ``is_train``;
        * per-dataset ``phase`` and ``data_type`` (``'img'``, ``'lmdb'`` or
          ``'mc'``), plus ``scale`` when ``distortion == 'sr'``;
        * user-expanded and derived entries under ``path``;
        * ``network_G.scale`` for SR;
        * absolute ``train`` schedule lists computed from the ``*_rel``
          fractions of ``train.niter``.

    Raises:
        KeyError: If a required key is missing, such as ``distortion``,
            ``datasets``, ``path``, ``name``, a dataset's ``mode``, or
            ``scale`` for SR.
    """
    # NOTE(review): yaml.load with the project Loader may be able to build
    # arbitrary Python objects, depending on how OrderedYaml() creates it.
    # Only parse trusted option files.
    with open(opt_path, mode='r', encoding='utf-8') as f:
        opt = yaml.load(f, Loader=Loader)

    # CUDA_VISIBLE_DEVICES is not exported here. An unused gpu_ids join was
    # removed; it crashed when the file contained `gpu_ids: ~`.
    opt['is_train'] = is_train
    is_sr = opt['distortion'] == 'sr'
    scale = opt['scale'] if is_sr else None

    _parse_datasets(opt['datasets'], is_sr, scale)
    _parse_paths(opt, is_train, root)

    # network
    if is_sr:
        opt['network_G']['scale'] = scale

    # relative learning-rate schedule
    if 'train' in opt:
        _resolve_relative_schedule(opt['train'])
        print(opt['train'])
    return opt


def _parse_datasets(datasets, is_sr, scale):
    """Annotate each dataset entry with its phase, SR scale and storage backend."""
    for phase_key, dataset in datasets.items():
        # 'train', 'val_1', ... -> phase 'train', 'val'
        dataset['phase'] = phase_key.split('_')[0]
        if is_sr:
            dataset['scale'] = scale
        is_lmdb = False
        for root_key in ('dataroot_GT', 'dataroot_LQ'):
            if dataset.get(root_key) is not None:
                dataset[root_key] = osp.expanduser(dataset[root_key])
                if dataset[root_key].endswith('lmdb'):
                    is_lmdb = True
        dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
        mode = dataset['mode']
        if mode.endswith('mc'):  # for memcached
            dataset['data_type'] = 'mc'
            # Strip only the trailing '_mc'; str.replace would also remove
            # any '_mc' occurring inside the mode name.
            if mode.endswith('_mc'):
                mode = mode[:-len('_mc')]
            dataset['mode'] = mode


def _parse_paths(opt, is_train, root):
    """Expand user paths and derive the experiment (train) or result (test) dirs."""
    path_opt = opt['path']
    for key, value in path_opt.items():
        # 'strict_load' is excluded from expansion (presumably a boolean flag,
        # not a filesystem path).
        if value and key != 'strict_load':
            path_opt[key] = osp.expanduser(value)
    path_opt['root'] = root
    if is_train:
        experiments_root = osp.join(root, 'experiments', opt['name'])
        path_opt['experiments_root'] = experiments_root
        path_opt['models'] = osp.join(experiments_root, 'models')
        path_opt['training_state'] = osp.join(experiments_root, 'training_state')
        path_opt['log'] = experiments_root
        path_opt['val_images'] = osp.join(experiments_root, 'val_images')
        # change some options for debug mode
        if 'debug' in opt['name']:
            opt['train']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        if not path_opt.get('results_root', None):
            path_opt['results_root'] = osp.join(root, 'results', opt['name'])
        path_opt['log'] = path_opt['results_root']


def _resolve_relative_schedule(train_opt):
    """Turn each ``<key>_rel`` list of ``niter`` fractions into absolute iterations."""
    for abs_key in ('T_period', 'restarts', 'lr_steps', 'lr_steps_inverse'):
        rel_key = abs_key + '_rel'
        if rel_key in train_opt:
            # Read niter only when a relative schedule actually needs it.
            niter = train_opt['niter']
            train_opt[abs_key] = [int(x * niter) for x in train_opt[rel_key]]
def dict2str(opt, indent_l=1):
    """Render a (nested) option mapping as an indented string for the logger.

    Each entry is indented by ``2 * indent_l`` spaces. A nested mapping is
    written as a ``key:[`` line, its entries one level deeper, and a closing
    ``]`` line. Other values are rendered as ``key: str(value)``. Keys go
    through ``str()``, so non-string keys (e.g. integer YAML keys) do not
    raise TypeError.

    Args:
        opt: Mapping to render.
        indent_l: Nesting level of ``opt``; each level adds two spaces.

    Returns:
        The rendered text, one line per entry, each ending in a newline.
    """
    pad = ' ' * (indent_l * 2)
    parts = []
    for k, v in opt.items():
        key = str(k)
        if isinstance(v, dict):
            parts.append(pad + key + ':[\n')
            parts.append(dict2str(v, indent_l + 1))
            parts.append(pad + ']\n')
        else:
            parts.append(pad + key + ': ' + str(v) + '\n')
    return ''.join(parts)
class NoneDict(dict):
    """A dict whose missing keys read as ``None`` instead of raising KeyError."""

    def __missing__(self, key):
        # dict.__getitem__ calls this on a miss. The key is not inserted, and
        # .get() / `in` are unaffected.
        return None


# Convert to NoneDict, which returns None for missing keys.
def dict_to_nonedict(opt):
    """Recursively convert dicts (including those inside lists) to NoneDict.

    Args:
        opt: Any value. Dicts and lists are traversed; tuples and other values
            are returned unchanged.

    Returns:
        A new structure in which every dict is a NoneDict with the same key
        order. ``opt`` itself is not modified.
    """
    if isinstance(opt, dict):
        # Pass the pairs positionally rather than via ** so that non-string
        # keys (e.g. integer YAML keys) are accepted.
        return NoneDict((key, dict_to_nonedict(value)) for key, value in opt.items())
    if isinstance(opt, list):
        return [dict_to_nonedict(item) for item in opt]
    return opt
def check_resume(opt, resume_iter):
    """Redirect the pretrained-model paths to the checkpoint being resumed.

    Does nothing unless ``opt['path']['resume_state']`` is truthy. Otherwise
    it warns if ``pretrain_model_G`` or ``pretrain_model_D`` is already set,
    since those are ignored on resume. It then sets ``pretrain_model_G`` to
    ``<models>/<resume_iter>_G.pth``. When ``opt['model']`` contains ``'gan'``,
    it also sets ``pretrain_model_D`` to ``<models>/<resume_iter>_D.pth``.
    ``opt`` is modified in place.
    """
    logger = logging.getLogger('base')
    path_opt = opt['path']
    if not path_opt['resume_state']:
        return

    pretrain_keys = ('pretrain_model_G', 'pretrain_model_D')
    if any(path_opt.get(key, None) is not None for key in pretrain_keys):
        logger.warning('pretrain_model path will be ignored when resuming training.')

    models_dir = path_opt['models']
    path_opt['pretrain_model_G'] = osp.join(models_dir, '{}_G.pth'.format(resume_iter))
    logger.info('Set [pretrain_model_G] to ' + path_opt['pretrain_model_G'])

    # Only GAN models carry a discriminator checkpoint.
    if 'gan' not in opt['model']:
        return
    path_opt['pretrain_model_D'] = osp.join(models_dir, '{}_D.pth'.format(resume_iter))
    logger.info('Set [pretrain_model_D] to ' + path_opt['pretrain_model_D'])
|