lambdanet / CAR /code /trainer.py
hyliu's picture
Upload folder using huggingface_hub
3ef0208 verified
import os
import math
from decimal import Decimal
import utility
import torch
import torch.nn.utils as utils
from tqdm import tqdm
import data
import torch.cuda.amp as amp
from torch.utils.tensorboard import SummaryWriter
import torchvision
import numpy as np
class Trainer:
    """Run the per-epoch training and evaluation loops for one model.

    The trainer builds the optimizer (via ``utility.make_optimizer``) and an
    AMP ``GradScaler``; the model, loss, data loaders and checkpoint object
    are supplied by the caller. When ``args.recurrence > 1`` a TensorBoard
    ``SummaryWriter`` is created under ``runs/<args.save>`` and per-epoch
    images and per-output losses are logged from :meth:`train`.
    """

    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Store collaborators and build the optimizer and gradient scaler.

        Args:
            args: Run configuration namespace. This class reads ``scale``,
                ``load``, ``amp``, ``recurrence``, ``save``, ``model``,
                ``patch_size``, ``batch_size``, ``gclip``, ``print_every``,
                ``rgb_range``, ``save_results``, ``save_gt``, ``test_only``,
                ``cpu``, ``precision`` and ``epochs`` from it.
            loader: Object exposing ``loader_train`` and ``loader_test``.
            my_model: Model invoked as ``model(lr, idx_scale)``.
            my_loss: Loss object invoked as ``loss(sr, hr)``.
            ckp: Checkpoint/logging object (``write_log``, ``add_log``,
                ``save``, ``log``, ``dir``, ...).
        """
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        # Resume the optimizer state when continuing from a saved run.
        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
        # The scaler is a pass-through when AMP is disabled.
        self.scaler = amp.GradScaler(enabled=args.amp)

        # Attribute name (including its spelling) is kept for compatibility.
        self.writter = None
        self.recurrence = args.recurrence
        if args.recurrence > 1:
            self.writter = SummaryWriter(f"runs/{args.save}")

    def train(self):
        """Train the model for one epoch.

        Steps the loss scheduler, optionally rebuilds the data loaders for
        the ``RECURSIONNET`` patch/batch-size schedule, runs one pass over
        ``loader_train`` with AMP-scaled backprop and optional gradient value
        clipping, then finalises the loss log and steps the LR schedule.
        When a TensorBoard writer exists, the last batch's images and the
        epoch-averaged per-output losses are logged with the epoch as step.
        """
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1

        if self.args.model == "RECURSIONNET":
            # Progressive schedule: larger patches later in training.
            if epoch > 50:
                self.args.patch_size = 64
            if epoch > 100:
                self.args.patch_size = 512
                self.args.batch_size = 1
            # NOTE(review): indentation was lost in the source this was
            # recovered from; the loader rebuild is assumed to belong to the
            # RECURSIONNET branch (after both epoch checks) -- confirm.
            _loader = data.Data(self.args)
            self.loader_train = _loader.loader_train
            self.loader_test = _loader.loader_test

        # Named distinctly from the low-resolution batch ``lr`` used below.
        learning_rate = self.optimizer.get_lr()

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(
                epoch, Decimal(learning_rate)
            )
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        # TEMP
        self.loader_train.dataset.set_scale(0)
        total = len(self.loader_train)
        # Running sum of the per-output losses; only tracked when the model
        # produces several outputs (recurrence > 1).
        buffer = ([0.0] * self.recurrence) if self.recurrence > 1 else []

        for batch, (lr, hr, _,) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            with amp.autocast(enabled=self.args.amp):
                sr = self.model(lr, 0)
                # A single-output model may still wrap its result in a list.
                if isinstance(sr, list) and len(sr) == 1:
                    sr = sr[0]
                loss = self.loss(sr, hr)

            self.scaler.scale(loss).backward()
            if self.args.gclip > 0:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                self.scaler.unscale_(self.optimizer)
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )
            self.scaler.step(self.optimizer)
            self.scaler.update()

            for i in range(len(buffer)):
                buffer[i] += self.loss.buffer[i]

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        # NOTE(review): placed after the batch loop because it averages by
        # ``total`` and uses ``epoch`` as the global step -- confirm.
        # Skipped for an empty loader, where ``sr``/``lr``/``hr`` are unbound
        # and the average would divide by zero.
        if self.writter is not None and total > 0:
            for i in range(self.recurrence):
                self.writter.add_image(
                    f"Output{i}", self._image_grid(sr[i]), epoch
                )
                self.writter.add_scalar(f"Loss{i}", buffer[i] / total, epoch)
            self.writter.add_image("Input", self._image_grid(lr), epoch)
            self.writter.add_image("Target", self._image_grid(hr), epoch)

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()

    def _image_grid(self, images):
        """Return a tiled image grid scaled to [0, 1] for TensorBoard.

        ``SummaryWriter.add_image`` expects float images in [0, 1]; tensors
        here are in ``[0, args.rgb_range]``, so they are rescaled (a no-op
        when ``rgb_range`` is 1) and clamped.
        """
        grid = torchvision.utils.make_grid(images).detach().float()
        return (grid / self.args.rgb_range).clamp(0, 1)

    def test(self):
        """Evaluate on every test loader and scale, then save a checkpoint.

        Autograd is disabled for the duration of the evaluation and is
        re-enabled afterwards even if evaluation raises, so a failure here
        cannot leave gradient tracking switched off for later training.
        """
        torch.set_grad_enabled(False)
        try:
            self._evaluate()
        finally:
            torch.set_grad_enabled(True)

    def _evaluate(self):
        """Run the evaluation pass; the caller manages the autograd mode."""
        epoch = self.optimizer.get_last_epoch()
        self.ckp.write_log('\nEvaluation:')
        # One new log row for this epoch: (dataset, scale) PSNR accumulators.
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        self.model.eval()

        timer_test = utility.timer()
        if self.args.save_results:
            self.ckp.begin_background()

        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    with amp.autocast(enabled=self.args.amp):
                        sr = self.model(lr, idx_scale)
                    # Multi-output models are scored on their final output.
                    if isinstance(sr, list):
                        sr = sr[-1]
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(d, filename[0], save_list, scale)

                # Turn the accumulated PSNR sum into a mean over the loader.
                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                # max over dim 0 yields (best values, row index of the best);
                # the index is shifted by one to report a 1-based epoch.
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )

        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')

        if self.args.save_results:
            self.ckp.end_background()

        if not self.args.test_only:
            # "Best" is decided on the first dataset at the first scale.
            # Assumes at least one test loader and one scale were evaluated,
            # otherwise ``best`` is unbound here.
            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

    def prepare(self, *args):
        """Move tensors to the run device, casting to half when configured.

        Args:
            *args: Tensors to transfer.

        Returns:
            A list with one converted tensor per input, in the same order.
        """
        device = torch.device('cpu' if self.args.cpu else 'cuda')

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        """Report whether the outer train/test loop should stop.

        In ``test_only`` mode a single evaluation is run and ``True`` is
        returned. Otherwise returns ``True`` once the next epoch number
        reaches ``args.epochs``.
        """
        if self.args.test_only:
            self.test()
            return True

        epoch = self.optimizer.get_last_epoch() + 1
        return epoch >= self.args.epochs