File size: 4,595 Bytes
6670ec8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import argparse
import random
import torch
from os import path as osp
from basicsr.data import create_dataloader, create_dataset
from basicsr.models import create_model
from basicsr.utils import (check_resume, make_exp_dirs, mkdir_and_rename, set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import parse
from basicsr.utils.nano import psf2otf
import numpy as np
from tqdm import tqdm
def parse_options(is_train=True):
parser = argparse.ArgumentParser()
parser.add_argument(
'-opt', type=str, required=True, help='Path to option YAML file.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm'],
default='none',
help='job launcher')
parser.add_argument(
'--name',
default=None,
help='job launcher')
import sys
vv = sys.version_info.minor
parser.add_argument('--local-rank', type=int, default=0)
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
opt = parse(args.opt, is_train=is_train, name=args.name if args.name is not None and args.name != "" else None)
# distributed settings
if args.launcher == 'none':
opt['dist'] = False
print('Disable distributed.', flush=True)
else:
opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt:
init_dist(args.launcher, **opt['dist_params'])
else:
init_dist(args.launcher)
print('init dist .. ', args.launcher)
opt['rank'], opt['world_size'] = get_dist_info()
# random seed
seed = opt.get('manual_seed')
if seed is None:
seed = random.randint(1, 10000)
opt['manual_seed'] = seed
set_random_seed(seed + opt['rank'])
return opt
def main():
# parse options, set distributed setting, set ramdom seed
opt = parse_options(is_train=True)
torch.backends.cudnn.benchmark = True
# automatic resume ..
state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
import os
try:
states = os.listdir(state_folder_path)
except:
states = []
resume_state = None
if len(states) > 0:
max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
resume_state = os.path.join(state_folder_path, max_state_file)
opt['path']['resume_state'] = resume_state
# load resume states if necessary
if opt['path'].get('resume_state'):
device_id = torch.cuda.current_device()
resume_state = torch.load(
opt['path']['resume_state'],
map_location=lambda storage, loc: storage.cuda(device_id))
else:
resume_state = None
# mkdir for experiments and logger
if resume_state is None:
make_exp_dirs(opt)
if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
'name'] and opt['rank'] == 0:
mkdir_and_rename(osp.join('tb_logger', opt['name']))
# define ks for Wiener filters
ks_params = opt['train'].get('ks', None)
if not ks_params:
raise NotImplementedError
M = ks_params['num']
ks = torch.logspace(ks_params['start'], ks_params['end'], M)
ks = ks.view(1,M,1,1,1,1).to("cuda")
val_conv = opt['val'].get("apply_conv", True)
# create model
if resume_state: # resume training
check_resume(opt, resume_state['iter'])
model = create_model(opt)
model.resume_training(resume_state) # handle optimizers and schedulers
current_iter = resume_state['iter']
else:
model = create_model(opt)
current_iter = 0
# load psf
psf = torch.tensor(np.load("./psf.npy")).to("cuda")
_,psf_h,psf_w,_ = psf.shape
otf = psf2otf(psf, h=psf_h*3, w=psf_w*3, permute=True)[None]
dataset_opt = opt['datasets']['val']
val_set = create_dataset(dataset_opt)
val_loader = create_dataloader(
val_set,
dataset_opt,
num_gpu=opt['num_gpu'],
dist=opt['dist'],
sampler=None,
seed=opt['manual_seed'])
print("Start validation on spatially varying aberrration")
rgb2bgr = opt['val'].get('rgb2bgr', True)
use_image = opt['val'].get('use_image', True)
psnr, others = model.validation(val_loader, current_iter, None, True, rgb2bgr, use_image, psf=otf, ks=ks, val_conv=val_conv)
print("==================")
print(f"Test results: PSNR: {psnr:.2f}, SSIM: {others['ssim']:.4f}, LPIPS: {others['lpips']:.4f}\n")
if __name__ == '__main__':
main()
|