FunSR / train_inr_funsr_ddp.py
KyanChen's picture
add
02c5426
raw
history blame
13.4 kB
import argparse
import json
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import yaml
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
import datasets
import models
import utils
def make_data_loader(spec, tag='', local_rank=0):
    """Build a DataLoader backed by a ``DistributedSampler`` from a dataset spec.

    Args:
        spec: Mapping with ``'dataset'`` and ``'wrapper'`` entries (both are
            handed to ``datasets.make``) plus ``'batch_size'`` and
            ``'num_workers'``. ``None`` means the split is not configured.
        tag: Label printed in the summary; the sampler shuffles only when it
            equals ``'train'``.
        local_rank: Only local rank 0 prints the dataset summary.

    Returns:
        The data loader, or ``None`` when ``spec`` is ``None``.

    Raises:
        NotImplementedError: On local rank 0, when the first sample holds a
            value that is neither a tensor, a string nor a dict.
    """
    if spec is None:
        return None

    inner_dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': inner_dataset})

    if local_rank == 0:
        print('{} dataset: size={}'.format(tag, len(dataset)))
        first_sample = dataset[0]
        for key, value in first_sample.items():
            if torch.is_tensor(value):
                print(' {}: shape={}'.format(key, value.shape))
                continue
            if isinstance(value, str):
                # Strings are accepted but carry no shape to report.
                continue
            if not isinstance(value, dict):
                raise NotImplementedError
            # Nested dict: report every entry that exposes a shape.
            for sub_key, sub_value in value.items():
                if hasattr(sub_value, 'shape'):
                    print(' {}: shape={}'.format(sub_key, sub_value.shape))

    shuffle = tag == 'train'
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec['batch_size'],
        num_workers=spec['num_workers'],
        pin_memory=True,
        sampler=sampler,
    )
def make_data_loaders(config, local_rank):
    """Create the training and validation loaders described by ``config``.

    Args:
        config: Mapping that may contain ``'train_dataset'`` and
            ``'val_dataset'`` specs; a missing entry yields ``None`` for that
            split.
        local_rank: Forwarded to ``make_data_loader``.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    splits = (('train_dataset', 'train'), ('val_dataset', 'val'))
    return tuple(
        make_data_loader(config.get(spec_key), tag=split_tag, local_rank=local_rank)
        for spec_key, split_tag in splits
    )
def _make_lr_scheduler(optimizer, scheduler_spec):
    """Instantiate the LR scheduler described by ``scheduler_spec`` (``None`` -> ``None``)."""
    if scheduler_spec is None:
        return None
    # Work on a copy so that popping 'name' does not mutate the caller's config.
    scheduler_kwargs = dict(scheduler_spec)
    scheduler_name = scheduler_kwargs.pop('name')
    if scheduler_name == 'MultiStepLR':
        return MultiStepLR(optimizer, **scheduler_kwargs)
    if scheduler_name == 'CosineAnnealingLR':
        return CosineAnnealingLR(optimizer, **scheduler_kwargs)
    if scheduler_name == 'CosineAnnealingWarmUpLR':
        return utils.warm_up_cosine_lr_scheduler(optimizer, **scheduler_kwargs)
    # Fail here rather than returning the raw dict, which would only blow up
    # later when the training loop calls ``.step()`` on it.
    raise NotImplementedError(
        'unsupported lr_scheduler name: {!r}'.format(scheduler_name))


def prepare_training(config, local_rank):
    """Create (or restore) the DDP-wrapped model, the optimizer and the LR scheduler.

    When ``config['resume']`` is set, the checkpoint at that path supplies the
    model spec, optimizer spec and last finished epoch; the scheduler is then
    built from the legacy ``config['multi_step_lr']`` entry (if any) and
    stepped forward to the resumed epoch. Otherwise everything is built fresh
    from ``config['model']``, ``config['optimizer']`` and the optional
    ``config['lr_scheduler']`` entry (a dict with a ``'name'`` key plus the
    scheduler's keyword arguments).

    Args:
        config: Training configuration mapping.
        local_rank: GPU index of this process; the model is placed on it.

    Returns:
        ``(model, optimizer, epoch_start, lr_scheduler)`` where ``model`` is
        always wrapped in ``DistributedDataParallel`` and ``lr_scheduler`` may
        be ``None``.

    Raises:
        NotImplementedError: If ``config['lr_scheduler']['name']`` is not one
            of the supported scheduler names.
    """
    if config.get('resume') is not None:
        # Deserialize on the CPU so that every rank does not materialize the
        # checkpoint on the GPU it was saved from.
        sv_file = torch.load(config['resume'], map_location='cpu')
        model = models.make(sv_file['model'], load_sd=True).cuda(local_rank)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # Wrap exactly like the fresh-start branch: the training loop reads
        # ``model.module``, which only exists on the DDP wrapper.
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        optimizer = utils.make_optimizer(
            model.parameters(), sv_file['optimizer'], load_sd=True)
        epoch_start = sv_file['epoch'] + 1
        if config.get('multi_step_lr') is None:
            lr_scheduler = None
        else:
            lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
            # Replay the schedule up to the resumed epoch (only when a
            # scheduler actually exists).
            # NOTE(review): if make_optimizer(load_sd=True) restores an already
            # decayed learning rate, replaying may apply the decay twice - confirm.
            for _ in range(epoch_start - 1):
                lr_scheduler.step()
    else:
        model = models.make(config['model']).cuda(local_rank)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        optimizer = utils.make_optimizer(
            model.parameters(), config['optimizer'])
        epoch_start = 1
        lr_scheduler = _make_lr_scheduler(optimizer, config.get('lr_scheduler'))

    if local_rank == 0:
        print('model: #params={}'.format(utils.compute_num_params(model, text=True)))
    return model, optimizer, epoch_start, lr_scheduler
def reduce_mean(tensor, nprocs):
    """Return ``tensor`` summed over all ranks and divided by ``nprocs``.

    Must be called by every process of the default process group, since it
    performs an ``all_reduce``. The input tensor is not modified.

    Args:
        tensor: Tensor to aggregate (same shape on every rank).
        nprocs: Divisor applied to the summed tensor, normally the world size.

    Returns:
        A new tensor, detached from the autograd graph, holding the mean.
    """
    # Detach before the in-place collective: all_reduce is not an autograd op,
    # so the result is only meaningful as a plain value (e.g. for logging) and
    # must not keep the caller's graph alive.
    reduced = tensor.detach().clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    # Out-of-place division also works for integer tensors.
    return reduced / nprocs
class AverageMeter(object):
    """Tracks the latest value and the running weighted average of a metric."""

    def __init__(self, name, fmt=':f'):
        # ``fmt`` is a format-spec suffix (e.g. ':.4f') used by ``__str__``.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def return_avg(self):
        """Return the current running average."""
        return self.avg

    def __str__(self):
        # Expands to e.g. '{name} {val:.4f} ({avg:.4f})' before filling in
        # the instance attributes.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))
def train(train_loader, model, optimizer, local_rank, data_norm=None):
    """Run one training epoch and return the average L1 loss.

    Every batch from ``train_loader`` is a dict of sub-batches; one sub-batch
    is picked at random per step. Its ``'img'`` and ``'gt'`` tensors are
    normalized, the model is called as ``model(img, gt.shape[-2:])`` and an L1
    loss is computed:

    * tuple output: ``0.2 * L1(pred[0]) + L1(pred[1])``;
    * list output: average of the per-output L1 losses weighted ``1..len``;
    * anything else: plain L1.

    Args:
        train_loader: Loader yielding the dict-of-dicts batches described above.
        model: Model to optimize (switched to training mode here).
        optimizer: Optimizer stepping the model parameters.
        local_rank: GPU index of this process.
        data_norm: Optional ``{'img': {'sub', 'div'}, 'gt': {'sub', 'div'}}``
            normalization statistics. When ``None`` the module-level
            ``config['data_norm']`` is used, which keeps existing four-argument
            callers working.

    Returns:
        The running average of the cross-rank mean loss, weighted by the local
        batch size.
    """
    if data_norm is None:
        # Backward-compatible fallback on the global populated by the script
        # entry point.
        data_norm = config['data_norm']

    model = model.train()
    loss_fn = nn.L1Loss().cuda(local_rank)
    train_losses = AverageMeter('Loss', ':.4e')

    def _stat(values, shape):
        return torch.tensor(values, dtype=torch.float32).view(*shape).cuda(local_rank)

    # Image statistics broadcast over the second axis, ground-truth statistics
    # over the last axis.
    img_sub = _stat(data_norm['img']['sub'], (1, -1, 1, 1))
    img_div = _stat(data_norm['img']['div'], (1, -1, 1, 1))
    gt_sub = _stat(data_norm['gt']['sub'], (1, 1, -1))
    gt_div = _stat(data_norm['gt']['div'], (1, 1, -1))

    # Only local rank 0 shows a progress bar.
    pbar = tqdm(total=len(train_loader), desc='train', leave=False) if local_rank == 0 else None
    try:
        for batch in train_loader:
            if pbar is not None:
                pbar.update(1)

            # Pick one of the sub-batches at random for this step.
            keys = list(batch.keys())
            batch = batch[keys[int(torch.randint(0, len(keys), []))]]
            for k, v in batch.items():
                if torch.is_tensor(v):
                    batch[k] = v.cuda(local_rank, non_blocking=True)

            img = (batch['img'] - img_sub) / img_div
            gt = (batch['gt'] - gt_sub) / gt_div
            pred = model(img, gt.shape[-2:])

            if isinstance(pred, tuple):
                loss = 0.2 * loss_fn(pred[0], gt) + loss_fn(pred[1], gt)
            elif isinstance(pred, list):
                # Linearly increasing weights 1..n, normalized by n(n+1)/2.
                losses = [loss_fn(x, gt) * (idx + 1) for idx, x in enumerate(pred)]
                loss = sum(losses) / ((1 + len(losses)) * len(losses) / 2)
            else:
                loss = loss_fn(pred, gt)

            torch.distributed.barrier()
            # The reduction is for logging only, so hand it a detached value
            # instead of running the collective on the autograd graph.
            reduced_loss = reduce_mean(loss.detach(), dist.get_world_size())
            train_losses.update(reduced_loss.item(), img.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    finally:
        # Close the bar even if a step raised.
        if pbar is not None:
            pbar.close()
    return train_losses.avg
def eval_psnr(loader, class_names, model, local_rank, data_norm=None, eval_type=None, eval_bsize=None, verbose=False, crop_border=4):
    """Evaluate ``model`` on ``loader`` and return the average PSNR and SSIM.

    Must be called by every rank: each step synchronizes and averages the
    per-batch metrics across processes.

    Args:
        loader: Loader yielding dict batches with ``'img'`` and ``'gt'`` tensors.
        class_names: Accepted for interface compatibility; not used here.
        model: Model called as ``model(img, gt.shape[-2:])`` (switched to eval
            mode). When it returns a list, the last element is evaluated.
        local_rank: GPU index of this process; local rank 0 shows progress.
        data_norm: Optional ``{'img': {'sub', 'div'}, 'gt': {'sub', 'div'}}``
            statistics; defaults to the identity normalization.
        eval_type: ``None`` or ``'psnr+ssim'`` for PSNR and SSIM via
            ``utils.calculate_psnr_pt`` / ``utils.calculate_ssim_pt``, or
            ``'div2k-<scale>'`` / ``'benchmark-<scale>'`` for PSNR only via
            ``utils.calc_psnr``.
        eval_bsize: Accepted for interface compatibility; not used here.
        verbose: When true, local rank 0 shows the running averages in the
            progress-bar description.
        crop_border: Border width forwarded to the PSNR/SSIM metrics.

    Returns:
        ``(psnr_avg, ssim_avg)``; ``ssim_avg`` is ``nan`` for the PSNR-only
        evaluation types.

    Raises:
        NotImplementedError: For an unrecognized ``eval_type``.
    """
    # `partial` is not imported at module level, so bring it in locally.
    from functools import partial

    crop_border = int(crop_border) if crop_border else crop_border
    if local_rank == 0:
        print('crop border: ', crop_border)
    model = model.eval()

    if data_norm is None:
        data_norm = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }

    def _stat(values):
        # Statistics broadcast over the second axis.
        return torch.tensor(values, dtype=torch.float32).view(1, -1, 1, 1).cuda(local_rank)

    img_sub = _stat(data_norm['img']['sub'])
    img_div = _stat(data_norm['img']['div'])
    gt_sub = _stat(data_norm['gt']['sub'])
    gt_div = _stat(data_norm['gt']['div'])

    # Normalize every evaluation type to a (psnr_fn, ssim_fn) pair of
    # callables taking (pred, gt); ssim_fn is None when SSIM is unavailable.
    if eval_type is None or eval_type == 'psnr+ssim':
        psnr_fn = partial(utils.calculate_psnr_pt, crop_border=crop_border)
        ssim_fn = partial(utils.calculate_ssim_pt, crop_border=crop_border)
    elif eval_type.startswith('div2k') or eval_type.startswith('benchmark'):
        dataset_name = 'div2k' if eval_type.startswith('div2k') else 'benchmark'
        scale = int(eval_type.split('-')[1])
        # NOTE(review): assumes utils.calc_psnr is called as
        # calc_psnr(pred, gt, dataset=..., scale=...), returns a tensor and
        # takes no crop_border argument - confirm against utils.
        psnr_fn = partial(utils.calc_psnr, dataset=dataset_name, scale=scale)
        ssim_fn = None
    else:
        raise NotImplementedError

    val_res_psnr = AverageMeter('psnr', ':.4f')
    val_res_ssim = AverageMeter('ssim', ':.4f')
    world_size = dist.get_world_size()

    pbar = tqdm(total=len(loader), desc='val', leave=False) if local_rank == 0 else None
    try:
        for batch in loader:
            if pbar is not None:
                pbar.update(1)

            for k, v in batch.items():
                if torch.is_tensor(v):
                    batch[k] = v.cuda(local_rank, non_blocking=True)

            img = (batch['img'] - img_sub) / img_div
            with torch.no_grad():
                pred = model(img, batch['gt'].shape[-2:])
            if isinstance(pred, list):
                pred = pred[-1]
            # Undo the ground-truth normalization before computing metrics.
            pred = pred * gt_div + gt_sub

            res_psnr = psnr_fn(pred, batch['gt']).mean()
            res_ssim = ssim_fn(pred, batch['gt']).mean() if ssim_fn is not None else None

            torch.distributed.barrier()
            reduced_psnr = reduce_mean(res_psnr, world_size)
            val_res_psnr.update(reduced_psnr.item(), img.size(0))
            if res_ssim is not None:
                reduced_ssim = reduce_mean(res_ssim, world_size)
                val_res_ssim.update(reduced_ssim.item(), img.size(0))

            if verbose and pbar is not None:
                pbar.set_description(
                    'val psnr: {:.4f} ssim: {:.4f}'.format(val_res_psnr.avg, val_res_ssim.avg))
    finally:
        # Close the bar even if a step raised.
        if pbar is not None:
            pbar.close()

    ssim_avg = val_res_ssim.avg if ssim_fn is not None else float('nan')
    return val_res_psnr.avg, ssim_avg
def main(config, save_path):
    """Run distributed training with periodic checkpointing and validation.

    Initializes the NCCL process group, builds the loaders, model, optimizer
    and scheduler, then trains from ``epoch_start`` to ``config['epoch_max']``.
    Rank 0 logs, writes ``epoch-last.pth`` every epoch, ``epoch-<n>.pth``
    every ``config['epoch_save_interval']`` epochs and ``epoch-best.pth``
    whenever the validation PSNR improves. Validation runs on all ranks every
    ``config['epoch_val_interval']`` epochs.

    Args:
        config: Training configuration mapping. ``config['data_norm']`` is
            filled with the identity normalization when absent, and
            ``config['model']`` / ``config['optimizer']`` receive an ``'sd'``
            entry with the current state dicts each epoch.
        save_path: Directory for the config dump, logs and checkpoints.
    """
    # torch.backends.cudnn.benchmark = True
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    # Bind this process to its own GPU so that collectives without an explicit
    # device (e.g. barrier) do not default to device 0 on every rank.
    torch.cuda.set_device(local_rank)
    print(f'rank: {rank} local_rank: {local_rank} world_size: {world_size}')

    # The logger/writer are only ever used under ``rank == 0`` below, so create
    # them under the same condition.
    if rank == 0:
        log, writer = utils.set_save_path(save_path)
        with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
            yaml.dump(config, f, sort_keys=False)

    train_loader, val_loader = make_data_loaders(config, local_rank)
    if config.get('data_norm') is None:
        config['data_norm'] = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }

    model, optimizer, epoch_start, lr_scheduler = prepare_training(config, local_rank)

    epoch_max = config['epoch_max']
    epoch_val_interval = config.get('epoch_val_interval')
    epoch_save_interval = config.get('epoch_save_interval')
    max_val_v = -1e18

    timer = utils.Timer()

    for epoch in range(epoch_start, epoch_max + 1):
        t_epoch_start = timer.t()
        # Give the distributed sampler a new shuffling seed every epoch.
        train_loader.sampler.set_epoch(epoch)
        train_loss = train(train_loader, model, optimizer, local_rank)
        if lr_scheduler is not None:
            lr_scheduler.step()

        if rank == 0:
            log_info = ['epoch {}/{}'.format(epoch, epoch_max)]
            log_info.append('train: loss={:.4f}'.format(train_loss))
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalars('loss', {'train': train_loss}, epoch)

        # Checkpoints store the unwrapped module's weights.
        model_ = model.module
        model_spec = config['model']
        model_spec['sd'] = model_.state_dict()
        optimizer_spec = config['optimizer']
        optimizer_spec['sd'] = optimizer.state_dict()
        sv_file = {
            'model': model_spec,
            'optimizer': optimizer_spec,
            'epoch': epoch
        }

        if rank == 0:
            torch.save(sv_file, os.path.join(save_path, 'epoch-last.pth'))
            if (epoch_save_interval is not None) and (epoch % epoch_save_interval == 0):
                torch.save(sv_file, os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if (epoch_val_interval is not None) and (epoch % epoch_val_interval == 0):
            split_file = config['val_dataset']['dataset']['args']['split_file']
            with open(split_file) as f:
                file_names = json.load(f)['test']
            # Class name = name of the directory containing each test file.
            class_names = sorted({os.path.basename(os.path.dirname(x)) for x in file_names})
            # Every rank takes part: evaluation performs collective reductions.
            val_res_psnr, val_res_ssim = eval_psnr(val_loader, class_names, model_, local_rank,
                                                   data_norm=config['data_norm'],
                                                   eval_type=config.get('eval_type'),
                                                   eval_bsize=config.get('eval_bsize'),
                                                   crop_border=4)
            if rank == 0:
                log_info.append('val: psnr={:.4f}'.format(val_res_psnr))
                writer.add_scalars('psnr', {'val': val_res_psnr}, epoch)
            if val_res_psnr > max_val_v:
                max_val_v = val_res_psnr
                if rank == 0:
                    torch.save(sv_file, os.path.join(save_path, 'epoch-best.pth'))

        t = timer.t()
        if rank == 0:
            # Fraction of the scheduled epochs done, used to extrapolate the
            # total run time.
            prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1)
            t_epoch = utils.time_text(t - t_epoch_start)
            t_elapsed, t_all = utils.time_text(t), utils.time_text(t / prog)
            log_info.append('{} {}/{}'.format(t_epoch, t_elapsed, t_all))

            log(', '.join(log_info))
            writer.flush()

    # Release the process group once training has finished.
    dist.destroy_process_group()
if __name__ == '__main__':
    # Command-line entry point: load the YAML config, derive the checkpoint
    # directory name and start training.
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', default='configs/train_1x-5x_INR_funsr.yaml')
    cli.add_argument('--name', default='EXP20221216_11')
    cli.add_argument('--tag', default=None)
    args = cli.parse_args()

    # `config` has to remain a module-level name: train() looks it up as a
    # global.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        print('config loaded.')

    if args.name is not None:
        run_name = args.name
    else:
        # Fall back to the config file name without its '.yaml' suffix.
        run_name = '_' + args.config.split('/')[-1][:-len('.yaml')]
    if args.tag is not None:
        run_name = run_name + '_' + args.tag

    save_path = os.path.join('./checkpoints', run_name)
    main(config, save_path)