import os
import math
from decimal import Decimal

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn.utils as utils
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import data
import utility

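# Trainer drives one experiment end to end: it wraps the model, loss,
# optimizer, and checkpoint helper, runs mixed-precision training via
# torch.cuda.amp, and (for recurrent models) logs per-recurrence outputs
# and losses to TensorBoard.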
class Trainer():
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
        # Gradient scaler for mixed-precision training; a no-op when AMP is off.
        self.scaler = amp.GradScaler(enabled=args.amp)

        # Per-recurrence TensorBoard logging only makes sense when the model
        # produces more than one intermediate output.
        self.writer = None
        self.recurrence = args.recurrence
        if args.recurrence > 1:
            self.writer = SummaryWriter(f"runs/{args.save}")

    def train(self):
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1

        # Patch-size curriculum for RECURSIONNET: grow the training crop as
        # the run progresses, dropping the batch size once crops get large,
        # then rebuild the loaders so the new sizes take effect.
        if self.args.model == "RECURSIONNET":
            if epoch > 50:
                self.args.patch_size = 64
            if epoch > 100:
                self.args.patch_size = 512
                self.args.batch_size = 1
            _loader = data.Data(self.args)
            self.loader_train = _loader.loader_train
            self.loader_test = _loader.loader_test

        lr = self.optimizer.get_lr()

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()

        self.loader_train.dataset.set_scale(0)
        total = len(self.loader_train)

        # Accumulates the per-recurrence loss terms over the epoch.
        buffer = [0.0] * self.recurrence if self.recurrence > 1 else []
        for batch, (lr, hr, _) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            with amp.autocast(enabled=self.args.amp):
                sr = self.model(lr, 0)
                if isinstance(sr, list) and len(sr) == 1:
                    sr = sr[0]
                loss = self.loss(sr, hr)
            self.scaler.scale(loss).backward()
            if self.args.gclip > 0:
                # Unscale before clipping so the threshold applies to the
                # true gradients, not the loss-scaled ones.
                self.scaler.unscale_(self.optimizer)
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )
            self.scaler.step(self.optimizer)
            self.scaler.update()
            for i in range(len(buffer)):
                buffer[i] += self.loss.buffer[i]

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        if self.writer:
            # Log the last batch of the epoch plus the mean per-recurrence loss.
            for i in range(self.recurrence):
                grid = torchvision.utils.make_grid(sr[i])
                self.writer.add_image(f"Output{i}", grid, epoch)
                self.writer.add_scalar(f"Loss{i}", buffer[i] / total, epoch)
            self.writer.add_image("Input", torchvision.utils.make_grid(lr), epoch)
            self.writer.add_image("Target", torchvision.utils.make_grid(hr), epoch)
        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()

    def test(self):
        torch.set_grad_enabled(False)

        epoch = self.optimizer.get_last_epoch()
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        self.model.eval()

        timer_test = utility.timer()
        if self.args.save_results:
            self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    with amp.autocast(enabled=self.args.amp):
                        sr = self.model(lr, idx_scale)
                    # Recurrent models return a list; evaluate the final output.
                    if isinstance(sr, list):
                        sr = sr[-1]
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(d, filename[0], save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )
        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')

        if self.args.save_results:
            self.ckp.end_background()

        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.optimizer.get_last_epoch() + 1
            return epoch >= self.args.epochs
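
# Typical driver loop (a minimal sketch; it assumes an entry script that
# builds `args`, `loader`, `model`, `loss`, and `ckp` the way the rest of
# this repo does -- the names below are illustrative, not defined here):
#
#   t = Trainer(args, loader, model, loss, ckp)
#   while not t.terminate():
#       t.train()
#       t.test()
#
# terminate() runs a single evaluation pass in test-only mode, and otherwise
# stops the loop once args.epochs training epochs have completed.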