|
|
import os |
|
|
import math |
|
|
import time |
|
|
import datetime |
|
|
from multiprocessing import Process |
|
|
from multiprocessing import Queue |
|
|
|
|
|
import matplotlib |
|
|
|
|
|
matplotlib.use('Agg') |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
import numpy as np |
|
|
import imageio |
|
|
import cv2 |
|
|
|
|
|
import torch |
|
|
import torch.optim as optim |
|
|
import torch.optim.lr_scheduler as lrs |
|
|
|
|
|
|
|
|
class timer():
    """Simple stopwatch with an accumulator for summing several intervals.

    ``tic`` starts an interval, ``toc`` reads it, ``hold`` adds the current
    interval to the accumulator, and ``release`` returns and clears the
    accumulator. All durations are in seconds.
    """

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Start (or restart) the current interval."""
        # perf_counter is monotonic; time.time() follows the wall clock and can
        # jump backwards/forwards on NTP or manual clock changes.
        self.t0 = time.perf_counter()

    def toc(self, restart=False):
        """Return seconds since the last ``tic``.

        Args:
            restart: when True, also start a new interval at the same instant.
        """
        now = time.perf_counter()
        diff = now - self.t0
        if restart:
            # Reuse `now` so no time is lost between reading and restarting.
            self.t0 = now
        return diff

    def hold(self):
        """Add the current interval's elapsed time to the accumulator.

        The interval is not restarted; call ``tic`` before the next span.
        """
        self.acc += self.toc()

    def release(self):
        """Return the accumulated time and reset the accumulator to 0."""
        ret = self.acc
        self.acc = 0
        return ret

    def reset(self):
        """Clear the accumulator without returning it."""
        self.acc = 0
|
|
|
|
|
|
|
|
class checkpoint():
    """Manage one experiment directory.

    Covers the log file, config dump, PSNR history, model/loss/optimizer
    checkpoints and background image writing. The directory is
    ``../experiment/<args.save>`` for a new run, or ``../experiment/<args.load>``
    when resuming.
    """

    def __init__(self, args):
        # Local import keeps this block self-contained (no file-level change needed).
        import shutil

        self.args = args
        self.ok = True
        # PSNR history; add_log concatenates new entries along dim 0.
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if not args.load:
            if not args.save:
                args.save = now
            self.dir = os.path.join('..', 'experiment', args.save)
        else:
            self.dir = os.path.join('..', 'experiment', args.load)
            if os.path.exists(self.dir):
                self.log = torch.load(self.get_path('psnr_log.pt'))
                print('Continue from epoch {}...'.format(len(self.log)))
            else:
                args.load = ''

        if args.reset:
            # shutil.rmtree instead of os.system('rm -rf ' + dir): no shell is
            # involved, so paths with spaces/metacharacters are handled safely.
            # ignore_errors mirrors `rm -rf` on a missing directory.
            shutil.rmtree(self.dir, ignore_errors=True)
            args.load = ''

        os.makedirs(self.dir, exist_ok=True)
        os.makedirs(self.get_path('model'), exist_ok=True)
        for d in args.data_test:
            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)

        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
        # Kept open for the run; closed in done() (and reopened by write_log(refresh=True)).
        self.log_file = open(self.get_path('log.txt'), open_type)
        with open(self.get_path('config.txt'), open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

        # Number of background image-writer processes started by begin_background().
        self.n_processes = 8

    def get_path(self, *subdir):
        """Return a path inside the experiment directory."""
        return os.path.join(self.dir, *subdir)

    def save(self, trainer, epoch, is_best=False):
        """Save the model, loss and optimizer state, the plots, and the PSNR log for `epoch`."""
        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        trainer.optimizer.save(self.dir)
        torch.save(self.log, self.get_path('psnr_log.pt'))

    def add_log(self, log):
        """Append `log` to the PSNR history along dim 0."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print `log` and append it to log.txt.

        When `refresh` is True, the file is closed and reopened so that
        buffered writes reach disk.
        """
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.get_path('log.txt'), 'a')

    def done(self):
        """Close the log file."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Save one PSNR-vs-epoch plot per test dataset, with one curve per scale.

        NOTE(review): this assumes self.log is indexed [epoch, dataset, scale]
        and that its first dimension equals `epoch` (the x-axis length).
        """
        axis = np.linspace(1, epoch, epoch)
        for idx_data, d in enumerate(self.args.data_test):
            label = 'SR on {}'.format(d)
            fig = plt.figure()
            plt.title(label)
            for idx_scale, scale in enumerate(self.args.scale):
                plt.plot(
                    axis,
                    self.log[:, idx_data, idx_scale].numpy(),
                    label='Scale {}'.format(scale)
                )
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('PSNR')
            plt.grid(True)
            plt.savefig(self.get_path('test_{}.pdf'.format(d)))
            plt.close(fig)

    @staticmethod
    def _bg_target(queue):
        """Worker loop that writes (filename, HxWx3 RGB tensor) items as BGR images.

        The loop stops when it receives a ``None`` filename. This is a static
        method rather than a closure so that it can be pickled under the
        'spawn' start method.
        """
        while True:
            # Blocking get: the previous `if not queue.empty()` poll spun at
            # 100% CPU in every idle worker.
            filename, tensor = queue.get()
            if filename is None:
                break
            cv2.imwrite(
                filename,
                cv2.cvtColor(tensor.numpy().astype(np.uint8), cv2.COLOR_RGB2BGR),
            )

    def begin_background(self):
        """Start `n_processes` workers that write images put on `self.queue`."""
        self.queue = Queue()
        self.process = [
            Process(target=self._bg_target, args=(self.queue,))
            for _ in range(self.n_processes)
        ]
        for p in self.process:
            p.start()

    def end_background(self):
        """Signal every worker to stop and wait for all of them to exit."""
        for _ in range(self.n_processes):
            self.queue.put((None, None))
        # The sentinels are queued behind every pending image and each worker
        # exits only after dequeuing one. Joining therefore drains the queue.
        for p in self.process:
            p.join()

    def save_results(self, dataset, filename, save_list, scale):
        """Queue the images in `save_list` for writing to ``results-<dataset name>``.

        The first image (model output) is written to ``<filename>.jpg``. Any
        later image (e.g. GT) is written to ``<filename>_<postfix>.jpg``, so
        it no longer overwrites the first one. Requires begin_background().
        """
        if not self.args.save_results:
            return

        base = self.get_path(
            'results-{}'.format(dataset.dataset.name),
            '{}'.format(filename)
        )

        postfix = ('SMGARN', 'GT')
        for idx, (v, p) in enumerate(zip(save_list, postfix)):
            normalized = v[0].mul(255 / self.args.rgb_range)
            # Clamp and round before the uint8 cast: .byte() truncates, and
            # out-of-range values would wrap around.
            tensor_cpu = normalized.clamp(0, 255).round().byte().permute(1, 2, 0).cpu()
            # Keep the historical name for the model output.
            if idx == 0:
                out_name = '{}.jpg'.format(base)
            else:
                out_name = '{}_{}.jpg'.format(base, p)
            self.queue.put((out_name, tensor_cpu))
|
|
|
|
|
|
|
|
def quantize(img, rgb_range):
    """Snap `img` to the 256-level 8-bit grid, expressed in `rgb_range` units.

    The image is scaled to [0, 255], clamped, rounded to integers and scaled
    back, so the result has the same range convention as the input.
    """
    scale = 255 / rgb_range
    levels = img.mul(scale).clamp(0, 255).round()
    return levels.div(scale)
|
|
|
|
|
|
|
|
def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
    """Return the PSNR in dB between `sr` and `hr`, ignoring a border of pixels.

    Args:
        sr, hr: image tensors on the same value scale (0..rgb_range). When the
            benchmark path is taken, dim 1 is treated as the channel axis.
        scale: upscaling factor; determines how many border pixels are shaved.
        rgb_range: maximum pixel value, used to normalise the difference.
        dataset: optional loader. If ``dataset.dataset.benchmark`` is truthy,
            a 3-channel difference is converted to luma and `scale` pixels are
            shaved. Otherwise ``scale + 6`` pixels are shaved on RGB.

    Returns:
        0 when `hr` has a single element (placeholder ground truth),
        ``float('inf')`` when the shaved images are identical, and otherwise
        ``-10 * log10(MSE)``.
    """
    if hr.nelement() == 1:
        return 0

    diff = (sr - hr) / rgb_range
    if dataset and dataset.dataset.benchmark:
        shave = scale
        if diff.size(1) > 1:
            # Y (luma) weights as in MATLAB rgb2ycbcr, rescaled by 256; the
            # constant +16 offset cancels in a difference, so it is omitted.
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
            diff = diff.mul(convert).sum(dim=1)
    else:
        shave = scale + 6

    valid = diff[..., shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()
    if mse == 0:
        # Identical images: PSNR is unbounded (math.log10(0) would raise).
        return float('inf')

    return -10 * math.log10(mse)
|
|
|
|
|
|
|
|
def make_optimizer(args, target):
    '''
    Make optimizer and scheduler together.

    Builds an SGD / ADAM / RMSprop optimizer (selected by ``args.optimizer``)
    over the parameters of `target` that require gradients. It attaches a
    MultiStepLR scheduler whose milestones are parsed from the dash-separated
    ``args.decay`` string (e.g. '200-400') with factor ``args.gamma``. The
    returned optimizer also exposes save/load/schedule/get_lr helpers.

    Raises:
        ValueError: if ``args.optimizer`` is not 'SGD', 'ADAM' or 'RMSprop'.
    '''
    trainable = [p for p in target.parameters() if p.requires_grad]
    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}

    if args.optimizer == 'SGD':
        optimizer_class = optim.SGD
        kwargs_optimizer['momentum'] = args.momentum
    elif args.optimizer == 'ADAM':
        optimizer_class = optim.Adam
        kwargs_optimizer['betas'] = args.betas
        kwargs_optimizer['eps'] = args.epsilon
    elif args.optimizer == 'RMSprop':
        optimizer_class = optim.RMSprop
        kwargs_optimizer['eps'] = args.epsilon
    else:
        # Previously this fell through and failed later with an UnboundLocalError.
        raise ValueError(
            'Unsupported optimizer: {!r} (expected SGD, ADAM or RMSprop)'.format(args.optimizer)
        )

    milestones = [int(x) for x in args.decay.split('-')]
    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
    scheduler_class = lrs.MultiStepLR

    class CustomOptimizer(optimizer_class):
        def _register_scheduler(self, scheduler_class, **kwargs):
            """Create the LR scheduler bound to this optimizer."""
            self.scheduler = scheduler_class(self, **kwargs)

        def save(self, save_dir):
            """Save the optimizer state_dict to <save_dir>/optimizer.pt."""
            torch.save(self.state_dict(), self.get_dir(save_dir))

        def load(self, load_dir, epoch=1):
            """Load <load_dir>/optimizer.pt and, if epoch > 1, step the scheduler `epoch` times."""
            self.load_state_dict(torch.load(self.get_dir(load_dir)))
            if epoch > 1:
                for _ in range(epoch):
                    self.scheduler.step()

        def get_dir(self, dir_path):
            """Return the checkpoint file path inside `dir_path`."""
            return os.path.join(dir_path, 'optimizer.pt')

        def schedule(self):
            """Advance the LR scheduler by one step."""
            self.scheduler.step()

        def get_lr(self):
            """Return the current learning rate of the first param group."""
            scheduler = self.scheduler
            # get_lr() is meant to be called inside step(). Calling it outside
            # makes MultiStepLR apply gamma again at milestone epochs, which
            # gives a wrong value. get_last_lr() (PyTorch >= 1.4) returns the
            # LR actually in use.
            if hasattr(scheduler, 'get_last_lr'):
                return scheduler.get_last_lr()[0]
            return scheduler.get_lr()[0]

        def get_last_epoch(self):
            """Return the scheduler's step counter."""
            return self.scheduler.last_epoch

    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
    return optimizer