import argparse
import json
import os
from functools import partial

import yaml
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import datasets
import models
import utils

def make_data_loader(spec, tag='', local_rank=0):
    if spec is None:
        return None

    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    if local_rank == 0:
        print('{} dataset: size={}'.format(tag, len(dataset)))
        for k, v in dataset[0].items():
            if torch.is_tensor(v):
                print('  {}: shape={}'.format(k, v.shape))
            elif isinstance(v, str):
                pass
            elif isinstance(v, dict):
                for k0, v0 in v.items():
                    if hasattr(v0, 'shape'):
                        print('  {}: shape={}'.format(k0, v0.shape))
            else:
                raise NotImplementedError

    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, shuffle=(tag == 'train'))
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=spec['batch_size'],
                                         num_workers=spec['num_workers'],
                                         pin_memory=True,
                                         sampler=sampler)
    return loader

def make_data_loaders(config, local_rank):
    train_loader = make_data_loader(config.get('train_dataset'), tag='train', local_rank=local_rank)
    val_loader = make_data_loader(config.get('val_dataset'), tag='val', local_rank=local_rank)
    return train_loader, val_loader

def prepare_training(config, local_rank):
    if config.get('resume') is not None:
        # Load the checkpoint on CPU first, then move the restored model to this rank's GPU
        # and wrap it in DDP as well, so that model.module and gradient sync work on resume.
        sv_file = torch.load(config['resume'], map_location='cpu')
        model = models.make(sv_file['model'], load_sd=True).cuda(local_rank)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        optimizer = utils.make_optimizer(
            model.parameters(), sv_file['optimizer'], load_sd=True)
        epoch_start = sv_file['epoch'] + 1
        if config.get('multi_step_lr') is None:
            lr_scheduler = None
        else:
            lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
            # Fast-forward the scheduler to the resumed epoch.
            for _ in range(epoch_start - 1):
                lr_scheduler.step()
    else:
        model = models.make(config['model']).cuda(local_rank)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        optimizer = utils.make_optimizer(
            model.parameters(), config['optimizer'])
        epoch_start = 1
        lr_scheduler = config.get('lr_scheduler')
        lr_scheduler_name = lr_scheduler.pop('name')
        if lr_scheduler_name == 'MultiStepLR':
            lr_scheduler = MultiStepLR(optimizer, **lr_scheduler)
        elif lr_scheduler_name == 'CosineAnnealingLR':
            lr_scheduler = CosineAnnealingLR(optimizer, **lr_scheduler)
        elif lr_scheduler_name == 'CosineAnnealingWarmUpLR':
            lr_scheduler = utils.warm_up_cosine_lr_scheduler(optimizer, **lr_scheduler)

    if local_rank == 0:
        print('model: #params={}'.format(utils.compute_num_params(model, text=True)))
    return model, optimizer, epoch_start, lr_scheduler

def reduce_mean(tensor, nprocs):
    # Average a tensor across all processes (used for logging losses/metrics).
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def return_avg(self):
        return self.avg

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def train(train_loader, model, optimizer, local_rank):
    model = model.train()
    loss_fn = nn.L1Loss().cuda(local_rank)
    train_losses = AverageMeter('Loss', ':.4e')

    # `config` is the module-level dict loaded in __main__.
    data_norm = config['data_norm']
    t = data_norm['img']
    img_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda(local_rank)
    img_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda(local_rank)
    t = data_norm['gt']
    gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda(local_rank)
    gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda(local_rank)

    if local_rank == 0:
        pbar = tqdm(total=len(train_loader), desc='train', leave=False)

    for i, batch in enumerate(train_loader):
        if local_rank == 0:
            pbar.update(1)
        # Each batch is a dict of sub-batches; pick one key at random for this step.
        keys = list(batch.keys())
        batch = batch[keys[torch.randint(0, len(keys), [])]]
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.cuda(local_rank, non_blocking=True)
        img = (batch['img'] - img_sub) / img_div
        gt = (batch['gt'] - gt_sub) / gt_div
        pred = model(img, gt.shape[-2:])
        if isinstance(pred, tuple):
            loss = 0.2 * loss_fn(pred[0], gt) + loss_fn(pred[1], gt)
        elif isinstance(pred, list):
            # Deep supervision: later outputs get linearly larger weights,
            # normalized by 1 + 2 + ... + n = n(n + 1) / 2.
            losses = [loss_fn(x, gt) for x in pred]
            losses = [x * (idx + 1) for idx, x in enumerate(losses)]
            loss = sum(losses) / ((1 + len(losses)) * len(losses) / 2)
        else:
            loss = loss_fn(pred, gt)

        # Average the loss across ranks for logging only; backward uses the local loss.
        torch.distributed.barrier()
        reduced_loss = reduce_mean(loss, dist.get_world_size())
        train_losses.update(reduced_loss.item(), img.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if local_rank == 0:
        pbar.close()
    return train_losses.avg

def eval_psnr(loader, class_names, model, local_rank, data_norm=None, eval_type=None,
              eval_bsize=None, verbose=False, crop_border=4):
    crop_border = int(crop_border) if crop_border else crop_border
    if local_rank == 0:
        print('crop border: ', crop_border)
    model = model.eval()

    if data_norm is None:
        data_norm = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    t = data_norm['img']
    img_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda(local_rank)
    img_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda(local_rank)
    t = data_norm['gt']
    gt_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda(local_rank)
    gt_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda(local_rank)

    if eval_type is None:
        metric_fn = [utils.calculate_psnr_pt, utils.calculate_ssim_pt]
    elif eval_type == 'psnr+ssim':
        metric_fn = [utils.calculate_psnr_pt, utils.calculate_ssim_pt]
    elif eval_type.startswith('div2k'):
        # Note: the loop below indexes metric_fn[0]/metric_fn[1], so these
        # single-callable branches are not compatible with it as written.
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
    elif eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
    else:
        raise NotImplementedError

    val_res_psnr = AverageMeter('psnr', ':.4f')
    val_res_ssim = AverageMeter('ssim', ':.4f')

    if local_rank == 0:
        pbar = tqdm(total=len(loader), desc='val', leave=False)
    for batch in loader:
        if local_rank == 0:
            pbar.update(1)
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.cuda(local_rank, non_blocking=True)

        img = (batch['img'] - img_sub) / img_div
        with torch.no_grad():
            pred = model(img, batch['gt'].shape[-2:])
        if isinstance(pred, list):
            pred = pred[-1]
        pred = pred * gt_div + gt_sub

        res_psnr = metric_fn[0](
            pred,
            batch['gt'],
            crop_border=crop_border
        ).mean()
        res_ssim = metric_fn[1](
            pred,
            batch['gt'],
            crop_border=crop_border
        ).mean()

        torch.distributed.barrier()
        reduced_val_res_psnr = reduce_mean(res_psnr, dist.get_world_size())
        reduced_val_res_ssim = reduce_mean(res_ssim, dist.get_world_size())

        val_res_psnr.update(reduced_val_res_psnr.item(), img.size(0))
        val_res_ssim.update(reduced_val_res_ssim.item(), img.size(0))

        if verbose and local_rank == 0:
            pbar.set_description(
                'val psnr: {:.4f} ssim: {:.4f}'.format(val_res_psnr.avg, val_res_ssim.avg))
    if local_rank == 0:
        pbar.close()
    return val_res_psnr.avg, val_res_ssim.avg

def main(config, save_path):
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    # Bind this process to its GPU so collectives (barrier/all_reduce) use the right device.
    torch.cuda.set_device(local_rank)
    print(f'rank: {rank} local_rank: {local_rank} world_size: {world_size}')

    if local_rank == 0:
        log, writer = utils.set_save_path(save_path)
        with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
            yaml.dump(config, f, sort_keys=False)

    train_loader, val_loader = make_data_loaders(config, local_rank)
    if config.get('data_norm') is None:
        config['data_norm'] = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }

    model, optimizer, epoch_start, lr_scheduler = prepare_training(config, local_rank)

    epoch_max = config['epoch_max']
    epoch_val_interval = config.get('epoch_val_interval')
    epoch_save_interval = config.get('epoch_save_interval')
    max_val_v = -1e18

    timer = utils.Timer()

    for epoch in range(epoch_start, epoch_max + 1):
        t_epoch_start = timer.t()
        # Reshuffle the distributed sampler each epoch.
        train_loader.sampler.set_epoch(epoch)

        train_loss = train(train_loader, model, optimizer, local_rank)
        if lr_scheduler is not None:
            lr_scheduler.step()

        if rank == 0:
            log_info = ['epoch {}/{}'.format(epoch, epoch_max)]
            log_info.append('train: loss={:.4f}'.format(train_loss))
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalars('loss', {'train': train_loss}, epoch)

        # Build the checkpoint from the unwrapped module on every rank
        # (model_ is also needed below for validation).
        model_ = model.module
        model_spec = config['model']
        model_spec['sd'] = model_.state_dict()
        optimizer_spec = config['optimizer']
        optimizer_spec['sd'] = optimizer.state_dict()
        sv_file = {
            'model': model_spec,
            'optimizer': optimizer_spec,
            'epoch': epoch
        }
        if rank == 0:
            torch.save(sv_file, os.path.join(save_path, 'epoch-last.pth'))

        if (epoch_save_interval is not None) and (epoch % epoch_save_interval == 0):
            if rank == 0:
                torch.save(sv_file, os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if (epoch_val_interval is not None) and (epoch % epoch_val_interval == 0):
            with open(config['val_dataset']['dataset']['args']['split_file']) as f:
                file_names = json.load(f)['test']
            class_names = list(set([os.path.basename(os.path.dirname(x)) for x in file_names]))

            val_res_psnr, val_res_ssim = eval_psnr(val_loader, class_names, model_, local_rank,
                                                   data_norm=config['data_norm'],
                                                   eval_type=config.get('eval_type'),
                                                   eval_bsize=config.get('eval_bsize'),
                                                   crop_border=4)
            if rank == 0:
                log_info.append('val: psnr={:.4f}'.format(val_res_psnr))
                writer.add_scalars('psnr', {'val': val_res_psnr}, epoch)
            if val_res_psnr > max_val_v:
                max_val_v = val_res_psnr
                if rank == 0:
                    torch.save(sv_file, os.path.join(save_path, 'epoch-best.pth'))

        t = timer.t()
        if rank == 0:
            prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1)
            t_epoch = utils.time_text(t - t_epoch_start)
            t_elapsed, t_all = utils.time_text(t), utils.time_text(t / prog)
            log_info.append('{} {}/{}'.format(t_epoch, t_elapsed, t_all))
            log(', '.join(log_info))
            writer.flush()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/train_1x-5x_INR_funsr.yaml')
    parser.add_argument('--name', default='EXP20221216_11')
    parser.add_argument('--tag', default=None)
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print('config loaded.')

    save_name = args.name
    if save_name is None:
        save_name = '_' + args.config.split('/')[-1][:-len('.yaml')]
    if args.tag is not None:
        save_name += '_' + args.tag
    save_path = os.path.join('./checkpoints', save_name)

    main(config, save_path)
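
# Usage note (not part of the original script; the file name "train.py" and the GPU
# count below are assumptions): the distributed setup above relies on the RANK /
# LOCAL_RANK / WORLD_SIZE environment variables that torchrun exports, so a typical
# multi-GPU launch would look like:
#
#   torchrun --nproc_per_node=4 train.py \
#       --config configs/train_1x-5x_INR_funsr.yaml --name EXP20221216_11
#
# Checkpoints (epoch-last.pth, epoch-best.pth, epoch-<N>.pth) and the dumped
# config.yaml are written under ./checkpoints/<name>.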