hail75 committed on
Commit
d2821a4
1 Parent(s): d698613

add train.py

Browse files
models/SRFlow/code/data/LRHR_PKL_dataset.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import os
18
+ import subprocess
19
+ import torch.utils.data as data
20
+ import numpy as np
21
+ import time
22
+ import torch
23
+
24
+ import pickle
25
+
26
+
27
class LRHR_PKLDataset(data.Dataset):
    """Paired LR/HR image dataset loaded eagerly from two pickle files.

    Options read from ``opt``:
        dataroot_GT, dataroot_LQ: paths of the HR and LR pickle files (required).
        GT_size: HR patch size used by the random crop (when ``use_crop`` is set).
        use_flip, use_rot, use_crop: enable the respective augmentation
            (each defaults to False when the key is absent).
        center_crop_hr_size: when truthy, centre-crop HR to this size and LR to
            ``center_crop_hr_size // scale``.
        n_max: maximum number of images kept from each file (default 1e8).

    Each item is a dict with tensors under ``'LQ'`` and ``'GT'`` (values divided
    by 255) and the item index, as a string, under ``'LQ_path'``/``'GT_path'``.
    """

    def __init__(self, opt):
        super(LRHR_PKLDataset, self).__init__()
        self.opt = opt
        self.crop_size = opt.get("GT_size", None)
        self.scale = None  # derived from the first item fetched, see __getitem__
        self.random_scale_list = [1]

        hr_file_path = opt["dataroot_GT"]
        lr_file_path = opt["dataroot_LQ"]

        self.use_flip = opt.get("use_flip", False)
        self.use_rot = opt.get("use_rot", False)
        self.use_crop = opt.get("use_crop", False)
        self.center_crop_hr_size = opt.get("center_crop_hr_size", None)

        n_max = opt.get("n_max", int(1e8))

        t = time.time()
        self.lr_images = self.load_pkls(lr_file_path, n_max)
        self.hr_images = self.load_pkls(hr_file_path, n_max)

        # Value range is sampled from the first 20 images only (cheap sanity print).
        min_val_hr = np.min([i.min() for i in self.hr_images[:20]])
        max_val_hr = np.max([i.max() for i in self.hr_images[:20]])

        min_val_lr = np.min([i.min() for i in self.lr_images[:20]])
        max_val_lr = np.max([i.max() for i in self.lr_images[:20]])

        # NOTE: t covers loading BOTH files; the same value is printed on both lines.
        t = time.time() - t
        print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
              format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path))
        print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
              format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path))

        self.gpu = True
        self.augment = True

        # Statistics of a recently fetched item; None until the first __getitem__.
        self.measures = None

    def load_pkls(self, path, n_max):
        """Load a pickled sequence of images and move the last axis to the front.

        Args:
            path: path of the pickle file; must exist.
            n_max: keep at most this many images.

        Returns:
            List of arrays transposed with axes ``[2, 0, 1]``
            (presumably HWC -> CHW; confirm against the data producer).

        Raises:
            AssertionError: if the file is missing or contains no images.
        """
        assert os.path.isfile(path), path
        with open(path, "rb") as f:
            # SECURITY NOTE: pickle.load can execute arbitrary code embedded in
            # the file. Only point the dataset options at trusted files.
            images = list(pickle.load(f))
        assert len(images) > 0, path
        images = images[:n_max]
        return [np.transpose(image, [2, 0, 1]) for image in images]

    def __len__(self):
        """Return the number of HR images."""
        return len(self.hr_images)

    def __getitem__(self, item):
        """Return the augmented, normalised LR/HR pair at index ``item``."""
        hr = self.hr_images[item]
        lr = self.lr_images[item]

        if self.scale is None:
            # Integer upscaling factor inferred once from the first pair seen.
            self.scale = hr.shape[1] // lr.shape[1]
        assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape)

        if self.use_crop:
            hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop)

        if self.center_crop_hr_size:
            hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale)

        if self.use_flip:
            hr, lr = random_flip(hr, lr)

        if self.use_rot:
            hr, lr = random_rotation(hr, lr)

        hr = hr / 255.0
        lr = lr / 255.0

        # Refresh the KPI statistics on the first item and then for ~5% of items.
        if self.measures is None or np.random.random() < 0.05:
            if self.measures is None:
                self.measures = {}
            self.measures['hr_means'] = np.mean(hr)
            self.measures['hr_stds'] = np.std(hr)
            self.measures['lr_means'] = np.mean(lr)
            self.measures['lr_stds'] = np.std(lr)

        hr = torch.Tensor(hr)
        lr = torch.Tensor(lr)

        return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)}

    def print_and_reset(self, tag):
        """Print the collected statistics as a ``[KPI]`` line and clear them.

        Safe to call when no statistics have been collected yet (prints an
        empty list instead of raising).
        """
        m = self.measures or {}
        kvs = ["{}={:.2f}".format(k, m[k]) for k in sorted(m)]
        print("[KPI] " + tag + ": " + ", ".join(kvs))
        self.measures = None
130
+
131
+
132
def random_flip(img, seg):
    """Flip both arrays along axis 2 together, or leave both untouched.

    A single coin toss decides for the pair, so the two arrays always stay
    aligned. Flipped results are returned as copies; otherwise the inputs are
    returned as-is.
    """
    keep_as_is = np.random.choice([True, False])
    if keep_as_is:
        return img, seg
    return np.flip(img, 2).copy(), np.flip(seg, 2).copy()
137
+
138
+
139
def random_rotation(img, seg):
    """Rotate both arrays by the same random number of quarter turns.

    The count is drawn from ``[0, 1, 3]`` (two quarter turns are not in the
    list) and applied in the plane of axes 1 and 2. Copies are returned.
    """
    quarter_turns = np.random.choice([0, 1, 3])
    rotated_img, rotated_seg = (
        np.rot90(arr, quarter_turns, axes=(1, 2)).copy() for arr in (img, seg)
    )
    return rotated_img, rotated_seg
144
+
145
+
146
def random_crop(hr, lr, size_hr, scale, random):
    """Cut a random, spatially aligned patch pair out of ``hr`` and ``lr``.

    Args:
        hr, lr: arrays cropped along axes 1 and 2.
        size_hr: side length of the HR patch; the LR patch uses
            ``size_hr // scale``.
        scale: integer factor between LR and HR coordinates.
        random: accepted for interface compatibility; not used.

    Returns:
        ``(hr_patch, lr_patch)`` as slices of the inputs.
    """
    size_lr = size_hr // scale
    lr_extent_x, lr_extent_y = lr.shape[1], lr.shape[2]

    # An offset is only drawn when the LR image is larger than the patch;
    # the x offset is drawn before the y offset.
    if lr_extent_x > size_lr:
        off_x = np.random.randint(low=0, high=(lr_extent_x - size_lr) + 1)
    else:
        off_x = 0
    if lr_extent_y > size_lr:
        off_y = np.random.randint(low=0, high=(lr_extent_y - size_lr) + 1)
    else:
        off_y = 0

    lr_patch = lr[:, off_x:off_x + size_lr, off_y:off_y + size_lr]

    # The HR window is the LR window scaled up, so both show the same region.
    hr_x, hr_y = off_x * scale, off_y * scale
    hr_patch = hr[:, hr_x:hr_x + size_hr, hr_y:hr_y + size_hr]

    return hr_patch, lr_patch
164
+
165
+
166
def center_crop(img, size):
    """Return the central ``size`` x ``size`` window over axes 1 and 2.

    Args:
        img: array whose axes 1 and 2 have equal length.
        size: side length of the window; must not exceed the image side and
            must differ from it by an even amount.

    Returns:
        A slice of ``img``; the full extent when ``size`` equals the image side.

    Raises:
        AssertionError: if the image is not square, ``size`` is larger than the
            image, or the border cannot be split evenly.
    """
    assert img.shape[1] == img.shape[2], img.shape
    border_double = img.shape[1] - size
    assert border_double >= 0, (img.shape, size)
    assert border_double % 2 == 0, (img.shape, size)
    border = border_double // 2
    # Use an explicit end index: "border:-border" with border == 0 is "0:-0",
    # i.e. "0:0", which would select nothing.
    end = img.shape[1] - border
    return img[:, border:end, border:end]
172
+
173
+
174
def center_crop_tensor(img, size):
    """Return the central ``size`` x ``size`` window over axes 2 and 3.

    Args:
        img: 4-D array/tensor whose axes 2 and 3 have equal length.
        size: side length of the window; must not exceed the image side and
            must differ from it by an even amount.

    Returns:
        A slice of ``img``; the full extent when ``size`` equals the image side.

    Raises:
        AssertionError: if the image is not square, ``size`` is larger than the
            image, or the border cannot be split evenly.
    """
    assert img.shape[2] == img.shape[3], img.shape
    border_double = img.shape[2] - size
    assert border_double >= 0, (img.shape, size)
    assert border_double % 2 == 0, (img.shape, size)
    border = border_double // 2
    # Explicit end index: a "-border" stop with border == 0 would select nothing.
    end = img.shape[2] - border
    return img[:, :, border:end, border:end]
models/SRFlow/code/data/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ '''create dataset and dataloader'''
18
+ import logging
19
+ import torch
20
+ import torch.utils.data
21
+
22
+
23
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``DataLoader`` for ``dataset`` according to its phase.

    Args:
        dataset: the dataset to wrap.
        dataset_opt: per-dataset options; ``phase`` defaults to ``'test'``.
            The train phase additionally reads ``n_workers`` and ``batch_size``.
        opt: global options; only ``gpu_ids`` is read (train phase). May be None.
        sampler: optional sampler for the train phase.

    Returns:
        For ``phase == 'train'``: a loader with ``drop_last=True`` whose worker
        count is ``n_workers * len(gpu_ids)``; it shuffles unless a sampler is
        given. Otherwise: a sequential loader with batch size 1.
    """
    phase = dataset_opt.get('phase', 'test')
    if phase != 'train':
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
                                           pin_memory=True)

    gpu_ids = opt.get('gpu_ids', None) if opt is not None else None
    gpu_ids = gpu_ids if gpu_ids else []
    num_workers = dataset_opt['n_workers'] * len(gpu_ids)
    batch_size = dataset_opt['batch_size']
    # DataLoader rejects shuffle=True combined with an explicit sampler, so
    # ordering is left to the sampler whenever one is supplied.
    shuffle = sampler is None
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers, sampler=sampler, drop_last=True,
                                       pin_memory=False)
37
+
38
+
39
def create_dataset(dataset_opt):
    """Instantiate the dataset selected by ``dataset_opt['mode']``.

    Args:
        dataset_opt: dataset options; ``mode`` selects the implementation and
            ``name`` is used in the log line. The whole mapping is passed to
            the dataset constructor.

    Returns:
        The constructed dataset.

    Raises:
        NotImplementedError: if ``mode`` is not ``'LRHR_PKL'``.
    """
    print(dataset_opt)
    mode = dataset_opt['mode']
    if mode == 'LRHR_PKL':
        # Imported lazily so the dataset module is only loaded when needed.
        from data.LRHR_PKL_dataset import LRHR_PKLDataset as D
    else:
        # '{}' rather than '{:s}': a non-string mode (e.g. None) must still
        # produce this error instead of a TypeError from str.format.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    logger.info('Dataset [{} - {}] is created.'.format(dataset.__class__.__name__,
                                                       dataset_opt['name']))
    return dataset
models/SRFlow/code/train.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16
+
17
+ import os
18
+ from os.path import basename
19
+ import math
20
+ import argparse
21
+ import random
22
+ import logging
23
+ import cv2
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ import torch.multiprocessing as mp
28
+
29
+ import options.options as option
30
+ from utils import util
31
+ from data import create_dataloader, create_dataset
32
+ from models import create_model
33
+ from utils.timer import Timer, TickTock
34
+ from utils.util import get_resume_paths
35
+
36
+ import wandb
37
+
38
+ def getEnv(name): import os; return True if name in os.environ.keys() else False
39
+
40
+
41
+ def init_dist(backend='nccl', **kwargs):
42
+ ''' initialization for distributed training'''
43
+ # if mp.get_start_method(allow_none=True) is None:
44
+ if mp.get_start_method(allow_none=True) != 'spawn':
45
+ mp.set_start_method('spawn')
46
+ rank = int(os.environ['RANK'])
47
+ num_gpus = torch.cuda.device_count()
48
+ torch.cuda.set_deviceDistIterSampler(rank % num_gpus)
49
+ dist.init_process_group(backend=backend, **kwargs)
50
+
51
+
52
+ def main():
53
+ wandb.init(project='srflow')
54
+ #### options
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
57
+ parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
58
+ help='job launcher')
59
+ parser.add_argument('--local_rank', type=int, default=0)
60
+ args = parser.parse_args()
61
+ opt = option.parse(args.opt, is_train=True)
62
+
63
+ #### distributed training settings
64
+ opt['dist'] = False
65
+ rank = -1
66
+ print('Disabled distributed training.')
67
+
68
+ #### loading resume state if exists
69
+ if opt['path'].get('resume_state', None):
70
+ resume_state_path, _ = get_resume_paths(opt)
71
+
72
+ # distributed resuming: all load into default GPU
73
+ if resume_state_path is None:
74
+ resume_state = None
75
+ else:
76
+ device_id = torch.cuda.current_device()
77
+ resume_state = torch.load(resume_state_path,
78
+ map_location=lambda storage, loc: storage.cuda(device_id))
79
+ option.check_resume(opt, resume_state['iter']) # check resume options
80
+ else:
81
+ resume_state = None
82
+
83
+ #### mkdir and loggers
84
+ if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
85
+ if resume_state is None:
86
+ util.mkdir_and_rename(
87
+ opt['path']['experiments_root']) # rename experiment folder if exists
88
+ util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
89
+ and 'pretrain_model' not in key and 'resume' not in key))
90
+
91
+ # config loggers. Before it, the log will not work
92
+ util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
93
+ screen=True, tofile=True)
94
+ util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
95
+ screen=True, tofile=True)
96
+ logger = logging.getLogger('base')
97
+ logger.info(option.dict2str(opt))
98
+
99
+ # tensorboard logger
100
+ if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
101
+ version = float(torch.__version__[0:3])
102
+ if version >= 1.1: # PyTorch 1.1
103
+ from torch.utils.tensorboard import SummaryWriter
104
+ else:
105
+ logger.info(
106
+ 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
107
+ from tensorboardX import SummaryWriter
108
+ conf_name = basename(args.opt).replace(".yml", "")
109
+ exp_dir = opt['path']['experiments_root']
110
+ log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
111
+ log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
112
+ tb_logger_train = SummaryWriter(log_dir=log_dir_train)
113
+ tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
114
+ else:
115
+ util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
116
+ logger = logging.getLogger('base')
117
+
118
+ # convert to NoneDict, which returns None for missing keys
119
+ opt = option.dict_to_nonedict(opt)
120
+
121
+ #### random seed
122
+ seed = opt['train']['manual_seed']
123
+ if seed is None:
124
+ seed = random.randint(1, 10000)
125
+ if rank <= 0:
126
+ logger.info('Random seed: {}'.format(seed))
127
+ util.set_random_seed(seed)
128
+
129
+ torch.backends.cudnn.benchmark = True
130
+ # torch.backends.cudnn.deterministic = True
131
+
132
+ #### create train and val dataloader
133
+ dataset_ratio = 200 # enlarge the size of each epoch
134
+ for phase, dataset_opt in opt['datasets'].items():
135
+ if phase == 'train':
136
+ full_dataset = create_dataset(dataset_opt)
137
+ print('Dataset created')
138
+ train_len = int(len(full_dataset) * 0.95)
139
+ val_len = len(full_dataset) - train_len
140
+ train_set, val_set = torch.utils.data.random_split(full_dataset, [train_len, val_len])
141
+ train_size = int(math.ceil(train_len / dataset_opt['batch_size']))
142
+ total_iters = int(opt['train']['niter'])
143
+ total_epochs = int(math.ceil(total_iters / train_size))
144
+ train_sampler = None
145
+ train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
146
+ if rank <= 0:
147
+ logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
148
+ len(train_set), train_size))
149
+ logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
150
+ total_epochs, total_iters))
151
+ val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1,
152
+ pin_memory=True)
153
+ elif phase == 'val':
154
+ continue
155
+ else:
156
+ raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
157
+ assert train_loader is not None
158
+
159
+ #### create model
160
+ current_step = 0 if resume_state is None else resume_state['iter']
161
+ model = create_model(opt, current_step)
162
+
163
+ #### resume training
164
+ if resume_state:
165
+ logger.info('Resuming training from epoch: {}, iter: {}.'.format(
166
+ resume_state['epoch'], resume_state['iter']))
167
+
168
+ start_epoch = resume_state['epoch']
169
+ current_step = resume_state['iter']
170
+ model.resume_training(resume_state) # handle optimizers and schedulers
171
+ else:
172
+ current_step = 0
173
+ start_epoch = 0
174
+
175
+ #### training
176
+ timer = Timer()
177
+ logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
178
+ timerData = TickTock()
179
+
180
+ for epoch in range(start_epoch, total_epochs + 1):
181
+ if opt['dist']:
182
+ train_sampler.set_epoch(epoch)
183
+
184
+ timerData.tick()
185
+ for _, train_data in enumerate(train_loader):
186
+ timerData.tock()
187
+ current_step += 1
188
+ if current_step > total_iters:
189
+ break
190
+
191
+ #### training
192
+ model.feed_data(train_data)
193
+
194
+ #### update learning rate
195
+ model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
196
+
197
+ try:
198
+ nll = model.optimize_parameters(current_step)
199
+ except RuntimeError as e:
200
+ print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ")
201
+ print(e)
202
+
203
+ if nll is None:
204
+ nll = 0
205
+
206
+ wandb.log({"loss": nll})
207
+ #### log
208
+ def eta(t_iter):
209
+ return (t_iter * (opt['train']['niter'] - current_step)) / 3600
210
+
211
+ if current_step % opt['logger']['print_freq'] == 0 \
212
+ or current_step - (resume_state['iter'] if resume_state else 0) < 25:
213
+ avg_time = timer.get_average_and_reset()
214
+ avg_data_time = timerData.get_average_and_reset()
215
+ message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}> '.format(
216
+ epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time,
217
+ eta(avg_time), nll)
218
+ print(message)
219
+ timer.tick()
220
+ # Reduce number of logs
221
+ if current_step % 5 == 0:
222
+ tb_logger_train.add_scalar('loss/nll', nll, current_step)
223
+ tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
224
+ tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
225
+ tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
226
+ tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step)
227
+ for k, v in model.get_current_log().items():
228
+ tb_logger_train.add_scalar(k, v, current_step)
229
+
230
+ # validation
231
+ if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
232
+ avg_psnr = 0.0
233
+ idx = 0
234
+ nlls = []
235
+ for val_data in val_loader:
236
+ idx += 1
237
+ img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
238
+ img_dir = os.path.join(opt['path']['val_images'], img_name)
239
+ util.mkdir(img_dir)
240
+
241
+ model.feed_data(val_data)
242
+
243
+ nll = model.test()
244
+ if nll is None:
245
+ nll = 0
246
+ nlls.append(nll)
247
+
248
+ visuals = model.get_current_visuals()
249
+
250
+ sr_img = None
251
+ # Save SR images for reference
252
+ if hasattr(model, 'heats'):
253
+ for heat in model.heats:
254
+ for i in range(model.n_sample):
255
+ sr_img = util.tensor2img(visuals['SR', heat, i]) # uint8
256
+ save_img_path = os.path.join(img_dir,
257
+ '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
258
+ current_step,
259
+ int(heat * 100), i))
260
+ util.save_img(sr_img, save_img_path)
261
+ else:
262
+ sr_img = util.tensor2img(visuals['SR']) # uint8
263
+ save_img_path = os.path.join(img_dir,
264
+ '{:s}_{:d}.png'.format(img_name, current_step))
265
+ util.save_img(sr_img, save_img_path)
266
+ assert sr_img is not None
267
+
268
+ # Save LQ images for reference
269
+ save_img_path_lq = os.path.join(img_dir,
270
+ '{:s}_LQ.png'.format(img_name))
271
+ if not os.path.isfile(save_img_path_lq):
272
+ lq_img = util.tensor2img(visuals['LQ']) # uint8
273
+ util.save_img(
274
+ cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
275
+ interpolation=cv2.INTER_NEAREST),
276
+ save_img_path_lq)
277
+
278
+ # Save GT images for reference
279
+ gt_img = util.tensor2img(visuals['GT']) # uint8
280
+ save_img_path_gt = os.path.join(img_dir,
281
+ '{:s}_GT.png'.format(img_name))
282
+ if not os.path.isfile(save_img_path_gt):
283
+ util.save_img(gt_img, save_img_path_gt)
284
+
285
+ # calculate PSNR
286
+ crop_size = opt['scale']
287
+ gt_img = gt_img / 255.
288
+ sr_img = sr_img / 255.
289
+ cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
290
+ cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
291
+ avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
292
+
293
+ avg_psnr = avg_psnr / idx
294
+ avg_nll = sum(nlls) / len(nlls)
295
+
296
+ # log
297
+ logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
298
+ logger_val = logging.getLogger('val') # validation logger
299
+ logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
300
+ epoch, current_step, avg_psnr))
301
+
302
+ # tensorboard logger
303
+ tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
304
+ tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)
305
+
306
+ tb_logger_train.flush()
307
+ tb_logger_valid.flush()
308
+
309
+ #### save models and training states
310
+ if current_step % opt['logger']['save_checkpoint_freq'] == 0:
311
+ if rank <= 0:
312
+ logger.info('Saving models and training states.')
313
+ model.save(current_step)
314
+ model.save_training_state(epoch, current_step)
315
+
316
+ timerData.tick()
317
+
318
+ with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
319
+ f.write("TRAIN_DONE")
320
+
321
+ if rank <= 0:
322
+ logger.info('Saving the final model.')
323
+ model.save('latest')
324
+ logger.info('End of training.')
325
+
326
+
327
+ if __name__ == '__main__':
328
+ main()