# lambdanet/SR/code/trainer.py
# Uploaded by hyliu via huggingface_hub (revision d6ec83b, verified).
import os
import math
from decimal import Decimal
import utility
import torch
import torch.nn.utils as utils
from tqdm import tqdm
import torch.cuda.amp as amp
from torch.utils.tensorboard import SummaryWriter
import torchvision
import numpy as np
class Trainer():
    """Drives training and evaluation of a super-resolution model.

    Wires together the data loaders, model, loss, optimizer, AMP grad
    scaler, and checkpoint helper. When ``args.recurrence > 1`` a
    TensorBoard writer is opened to log per-recurrence outputs.
    """

    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale
        self.ckp = ckp

        # Data / model / loss plumbing.
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss

        self.optimizer = utility.make_optimizer(args, self.model)
        if self.args.load != '':
            # Resume optimizer state; epoch count is inferred from the
            # checkpoint's PSNR log length.
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
        # Grad scaler is a no-op when AMP is disabled.
        self.scaler = amp.GradScaler(enabled=args.amp)

        # TensorBoard logging only for recurrent (multi-output) models.
        # NOTE(review): attribute is spelled "writter" throughout the file;
        # kept as-is for compatibility with sibling methods.
        self.writter = None
        self.recurrence = args.recurrence
        if args.recurrence > 1:
            self.writter = SummaryWriter(f"runs/{args.save}")
def train(self):
self.loss.step()
epoch = self.optimizer.get_last_epoch() + 1
lr = self.optimizer.get_lr()
self.ckp.write_log(
'[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
)
self.loss.start_log()
self.model.train()
timer_data, timer_model = utility.timer(), utility.timer()
# TEMP
self.loader_train.dataset.set_scale(0)
total=len(self.loader_train)
buffer=[0.0]*self.recurrence
# torch.autograd.set_detect_anomaly(True)
for batch, (lr, hr, _,) in enumerate(self.loader_train):
lr, hr = self.prepare(lr, hr)
# print(lr.min(),lr.max(), hr.min(),hr.max())
# exit(0)
timer_data.hold()
timer_model.tic()
self.optimizer.zero_grad()
with amp.autocast(self.args.amp):
sr = self.model(lr, 0)
if len(sr)==1:
sr=sr[0]
# loss,buffer_lst=sequence_loss(sr,hr)
loss = self.loss(sr, hr)
self.scaler.scale(loss).backward()
if self.args.gclip > 0:
self.scaler.unscale_(self.optimizer)
utils.clip_grad_value_(
self.model.parameters(),
self.args.gclip
)
self.scaler.step(self.optimizer)
self.scaler.update()
for i in range(self.recurrence):
buffer[i]+=self.loss.buffer[i]
# self.optimizer.step()
timer_model.hold()
if (batch + 1) % self.args.print_every == 0:
self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
(batch + 1) * self.args.batch_size,
len(self.loader_train.dataset),
self.loss.display_loss(batch),
timer_model.release(),
timer_data.release()))
timer_data.tic()
if self.writter:
for i in range(self.recurrence):
grid=torchvision.utils.make_grid(sr[i])
self.writter.add_image(f"Output{i}",grid,epoch)
self.writter.add_scalar(f"Loss{i}",buffer[i]/total,epoch)
self.writter.add_image("Input",torchvision.utils.make_grid(lr),epoch)
self.writter.add_image("Target",torchvision.utils.make_grid(hr),epoch)
self.loss.end_log(len(self.loader_train))
self.error_last = self.loss.log[-1, -1]
self.optimizer.schedule()
def test(self):
torch.set_grad_enabled(False)
epoch = self.optimizer.get_last_epoch()
self.ckp.write_log('\nEvaluation:')
self.ckp.add_log(
torch.zeros(1, len(self.loader_test), len(self.scale))
)
self.model.eval()
timer_test = utility.timer()
if self.args.save_results: self.ckp.begin_background()
for idx_data, d in enumerate(self.loader_test):
for idx_scale, scale in enumerate(self.scale):
d.dataset.set_scale(idx_scale)
for lr, hr, filename in tqdm(d, ncols=80):
lr, hr = self.prepare(lr, hr)
with amp.autocast(self.args.amp):
sr = self.model(lr, idx_scale)
if isinstance(sr,list):
sr=sr[-1]
sr = utility.quantize(sr, self.args.rgb_range)
save_list = [sr]
self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
sr, hr, scale, self.args.rgb_range, dataset=d
)
if self.args.save_gt:
save_list.extend([lr, hr])
if self.args.save_results:
self.ckp.save_results(d, filename[0], save_list, scale)
self.ckp.log[-1, idx_data, idx_scale] /= len(d)
best = self.ckp.log.max(0)
self.ckp.write_log(
'[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
d.dataset.name,
scale,
self.ckp.log[-1, idx_data, idx_scale],
best[0][idx_data, idx_scale],
best[1][idx_data, idx_scale] + 1
)
)
self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
self.ckp.write_log('Saving...')
# torch.cuda.empty_cache()
if self.args.save_results:
self.ckp.end_background()
if not self.args.test_only:
self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
self.ckp.write_log(
'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
torch.set_grad_enabled(True)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
return [_prepare(a) for a in args]
def terminate(self):
if self.args.test_only:
self.test()
return True
else:
epoch = self.optimizer.get_last_epoch() + 1
return epoch >= self.args.epochs