kxhit
update
5f093a6
raw
history blame
12.1 kB
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Main function for training one epoch or testing
# --------------------------------------------------------
import math
import sys
from typing import Iterable
import numpy as np
import torch
import torchvision
from utils import misc as misc
def split_prediction_conf(predictions, with_conf=False):
    """Separate the trailing confidence channel from a network output.

    Args:
        predictions: 4-D tensor indexed as ``[batch, channel, row, col]``.
        with_conf: when true, the last channel is treated as a confidence map.

    Returns:
        ``(predictions, conf)``. With ``with_conf`` false the input is returned
        untouched and ``conf`` is ``None``; otherwise ``predictions`` holds all
        channels but the last and ``conf`` the last one (kept as a
        single-channel 4-D tensor).
    """
    if with_conf:
        # ``-1:`` (rather than ``-1``) keeps the channel dimension on the confidence map
        return predictions[:, :-1, :, :], predictions[:, -1:, :, :]
    return predictions, None
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, metrics: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None, print_freq = 20,
                    args=None):
    """Run one training epoch with gradient accumulation and return averaged stats.

    Args:
        model: network called as ``model(image1, image2)``.
        criterion: loss module; its ``with_conf`` attribute tells whether the
            last output channel is a confidence map that must be passed to it.
        metrics: callable returning a dict of scalar tensors for logging.
        data_loader: yields ``(image1, image2, gt, pairname)`` batches.
        optimizer: optimizer whose learning rate is adjusted per iteration.
        device: device the batches are moved to.
        epoch: current epoch index (used for the lr schedule and log x-axis).
        loss_scaler: callable performing backward (and the optimizer step when
            ``update_grad`` is true).
        log_writer: optional tensorboard-like writer (``add_scalar``, ``log_dir``).
        print_freq: console logging period forwarded to the metric logger.
        args: namespace providing ``accum_iter``, ``img_per_epoch``,
            ``batch_size``, ``amp``, ``tboard_log_step`` and whatever
            ``misc.adjust_learning_rate`` reads.

    Returns:
        Dict mapping each meter name to its global average over the epoch.

    The process exits with status 1 if the loss becomes non-finite.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    if args.img_per_epoch:
        # number of batches needed to see img_per_epoch images (rounded up)
        iter_per_epoch = args.img_per_epoch // args.batch_size + int(args.img_per_epoch % args.batch_size > 0)
        assert len(data_loader) >= iter_per_epoch, 'Dataset is too small for so many iterations'
        len_data_loader = iter_per_epoch
    else:
        len_data_loader, iter_per_epoch = len(data_loader), None

    for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_logger.log_every(data_loader, print_freq, header, max_iter=iter_per_epoch)):
        image1 = image1.to(device, non_blocking=True)
        image2 = image2.to(device, non_blocking=True)
        gt = gt.to(device, non_blocking=True)

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(optimizer, data_iter_step / len_data_loader + epoch, args)

        with torch.cuda.amp.autocast(enabled=bool(args.amp)):
            prediction = model(image1, image2)
            prediction, conf = split_prediction_conf(prediction, criterion.with_conf)
            batch_metrics = metrics(prediction.detach(), gt)
            loss = criterion(prediction, gt) if conf is None else criterion(prediction, gt, conf)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # gradients are accumulated over accum_iter steps, hence the rescaling
        loss /= accum_iter
        is_update_step = (data_iter_step + 1) % accum_iter == 0
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()

        # synchronising requires an initialised CUDA runtime; skip it on CPU-only hosts
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        for k, v in batch_metrics.items():
            metric_logger.update(**{k: v.item()})
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        time_to_log = ((data_iter_step + 1) % (args.tboard_log_step * accum_iter) == 0 or data_iter_step == len_data_loader-1)
        # called on every iteration, whether or not a writer is attached
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and time_to_log:
            epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000)
            # We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes.
            log_writer.add_scalar('train/loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)
            for k, v in batch_metrics.items():
                log_writer.add_scalar('train/'+k, v.item(), epoch_1000x)

    # gather the stats from all processes
    #if args.distributed: metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def validate_one_epoch(model: torch.nn.Module,
                       criterion: torch.nn.Module,
                       metrics: torch.nn.Module,
                       data_loaders: list[Iterable],
                       device: torch.device,
                       epoch: int,
                       log_writer=None,
                       args=None):
    """Evaluate the model with tiled inference on one or several validation sets.

    Args:
        model: network under evaluation (switched to eval mode here).
        criterion: loss module; ``criterion.with_conf`` selects whether the
            confidence returned by ``tiled_pred`` is passed to it.
        metrics: callable returning a dict of scalar tensors.
        data_loaders: loaders yielding ``(image1, image2, gt, pairname)``;
            ``str(loader.dataset)`` is used as the dataset name.
        device: device the batches are moved to.
        epoch: current epoch index (used for the log header and x-axis).
        log_writer: optional tensorboard-like writer.
        args: namespace providing ``tile_conf_mode``, ``crop`` and ``val_overlap``.

    Returns:
        Dict of averaged values: ``loss``, ``loss_tiled``, ``<dataset>_<metric>``
        and, when more than one dataset is given, ``AVG_<metric>``.
    """
    model.eval()
    metric_loggers = []
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    conf_mode = args.tile_conf_mode
    crop = args.crop

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    results = {}
    dnames = []
    # defined up-front so the averaging step below is safe when no batch was seen
    batch_metrics = {}
    for didx, data_loader in enumerate(data_loaders):
        dname = str(data_loader.dataset)
        dnames.append(dname)
        metric_loggers.append(misc.MetricLogger(delimiter=" "))
        for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_loggers[didx].log_every(data_loader, print_freq, header)):
            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            gt = gt.to(device, non_blocking=True)
            if dname.startswith('Spring'):
                assert gt.size(2)==image1.size(2)*2 and gt.size(3)==image1.size(3)*2
                gt = (gt[:,:,0::2,0::2] + gt[:,:,0::2,1::2] + gt[:,:,1::2,0::2] + gt[:,:,1::2,1::2] ) / 4.0 # we approximate the gt based on the 2x upsampled ones

            with torch.inference_mode():
                prediction, tiled_loss, c = tiled_pred(model, criterion, image1, image2, gt, conf_mode=conf_mode, overlap=args.val_overlap, crop=crop, with_conf=criterion.with_conf)
                batch_metrics = metrics(prediction.detach(), gt)
                loss = criterion(prediction.detach(), gt) if not criterion.with_conf else criterion(prediction.detach(), gt, c)
            loss_value = loss.item()

            metric_loggers[didx].update(loss_tiled=tiled_loss.item())
            metric_loggers[didx].update(loss=loss_value)
            for k, v in batch_metrics.items():
                metric_loggers[didx].update(**{dname+'_' + k: v.item()})

    # NOTE(review): 'loss' and 'loss_tiled' are not prefixed with the dataset name,
    # so with several loaders the values of the last one win in this dict.
    results = {k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()}
    if len(dnames)>1:
        # metric names are taken from the last batch evaluated
        for k in batch_metrics.keys():
            results['AVG_'+k] = sum(results[dname+'_'+k] for dname in dnames) / len(dnames)

    if log_writer is not None :
        epoch_1000x = int((1 + epoch) * 1000)
        for k, v in results.items():
            log_writer.add_scalar('val/'+k, v, epoch_1000x)

    print("Averaged stats:", results)
    return results
import torch.nn.functional as F
def _resize_img(img, new_size):
    """Resample ``img`` to spatial size ``new_size`` using bicubic interpolation."""
    resized = F.interpolate(
        img,
        size=new_size,
        mode='bicubic',
        align_corners=False,
    )
    return resized
def _resize_stereo_or_flow(data, new_size):
    """Resize a 1- or 2-channel displacement map and rescale its values.

    The map is resampled with bicubic interpolation to ``new_size``
    (``(height, width)``). Channel 0 is multiplied by the width ratio and, for
    2-channel inputs, channel 1 by the height ratio, so that the stored values
    stay consistent with the new resolution (channel 0 is presumably the
    horizontal displacement/disparity and channel 1 the vertical one).

    Args:
        data: 4-D tensor with 1 or 2 channels in dimension 1.
        new_size: target ``(height, width)``.

    Returns:
        A new tensor of spatial size ``new_size``; ``data`` is not modified.

    Raises:
        AssertionError: if ``data`` is not 4-D or has another channel count.
    """
    assert data.ndim==4
    assert data.size(1) in [1,2]
    scale_x = new_size[1]/float(data.size(3))
    out = F.interpolate(data, size=new_size, mode='bicubic', align_corners=False)
    # in-place scaling is safe: ``out`` is the fresh tensor returned by interpolate
    out[:,0,:,:] *= scale_x
    if out.size(1)==2:
        scale_y = new_size[0]/float(data.size(2))
        out[:,1,:,:] *= scale_y
    return out
@torch.no_grad()
def tiled_pred(model, criterion, img1, img2, gt,
               overlap=0.5, bad_crop_thr=0.05,
               downscale=False, crop=512, ret='loss',
               conf_mode='conf_expsigmoid_10_5', with_conf=False,
               return_time=False):
    """Predict on overlapping tiles and fuse them by confidence-weighted averaging.

    The model is run on every ``crop``-sized window of the image pair; each
    tile prediction is weighted by a function of the predicted confidence and
    the weighted predictions are averaged per pixel.

    Args:
        model: network called as ``model(tile1, tile2)``.
        criterion: optional loss evaluated per tile (needs ``gt``).
        img1, img2: 4-D image tensors.
        gt: optional ground truth; when given it fixes the output shape and is
            resized along with the images if upscaling is required.
        overlap: overlap fraction forwarded to ``_overlapping``.
        bad_crop_thr, downscale, ret: unused, kept for interface compatibility.
        crop: ``(height, width)`` of a tile; a single int means a square tile.
        conf_mode: ``'conf_expsigmoid_<beta>_<betasigmoid>'`` or
            ``'conf_expbeta<beta>'``; selects the tile weighting function.
        with_conf: whether the model output carries a confidence channel.
            The weighting is computed from that channel, so the fusion below
            fails when it is absent.
        return_time: also return the elapsed time in seconds measured with CUDA events.

    Returns:
        ``(pred, tiled_loss, c)`` or ``(pred, tiled_loss, c, time)``, where
        ``tiled_loss`` is the mean of the per-tile losses (NaN when none were
        computed) and ``c`` the weighted average of the predicted confidences.

    Raises:
        NotImplementedError: for an unknown ``conf_mode``.
    """
    # the default value is a bare int; interpret it as a square window
    if isinstance(crop, int):
        crop = (crop, crop)

    # for each image, we are going to run inference on many overlapping patches
    # then, all predictions will be weighted-averaged
    if gt is not None:
        B, C, H, W = gt.shape
    else:
        B, _, H, W = img1.shape
        C = model.head.num_channels-int(with_conf)
    win_height, win_width = crop[0], crop[1]

    # upscale to be larger than the crop
    do_change_scale = H<win_height or W<win_width
    if do_change_scale:
        # each ratio must be taken along its own axis, otherwise a short but
        # wide image could end up being shrunk instead of enlarged
        upscale_factor = max(win_width/W, win_height/H)
        original_size = (H,W)
        new_size = (round(H*upscale_factor),round(W*upscale_factor))
        img1 = _resize_img(img1, new_size)
        img2 = _resize_img(img2, new_size)
        # resize gt just for the computation of tiled losses
        if gt is not None: gt = _resize_stereo_or_flow(gt, new_size)
        H,W = img1.shape[2:4]

    use_expsigmoid = conf_mode.startswith('conf_expsigmoid_')
    if use_expsigmoid: # conf_expsigmoid_30_10
        beta, betasigmoid = map(float, conf_mode[len('conf_expsigmoid_'):].split('_'))
    elif conf_mode.startswith('conf_expbeta'): # conf_expbeta3
        beta = float(conf_mode[len('conf_expbeta'):])
    else:
        raise NotImplementedError(f"conf_mode {conf_mode} is not implemented")

    def crop_generator():
        # same window for both images
        for sy in _overlapping(H, win_height, overlap):
            for sx in _overlapping(W, win_width, overlap):
                yield sy, sx, sy, sx, True

    # keep track of weighted sum of prediction*weights and weights
    accu_pred = img1.new_zeros((B, C, H, W)) # accumulate the weighted sum of predictions
    accu_conf = img1.new_zeros((B, H, W)) + 1e-16 # accumulate the weights
    accu_c = img1.new_zeros((B, H, W)) # accumulate the weighted sum of confidences ; not so useful except for computing some losses

    tiled_losses = []

    if return_time:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

    for sy1, sx1, sy2, sx2, _aligned in crop_generator():
        # compute optical flow there
        pred = model(_crop(img1,sy1,sx1), _crop(img2,sy2,sx2))
        pred, predconf = split_prediction_conf(pred, with_conf=with_conf)

        if criterion is not None and gt is not None:
            gtcrop = _crop(gt,sy1,sx1)
            tiled_losses.append( criterion(pred, gtcrop).item() if predconf is None else criterion(pred, gtcrop, predconf).item() )

        # per-pixel fusion weight derived from the predicted confidence
        if use_expsigmoid:
            conf = torch.exp(- beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5)).view(B,win_height,win_width)
        else:
            conf = torch.exp(- beta * predconf).view(B,win_height,win_width)

        accu_pred[...,sy1,sx1] += pred * conf[:,None,:,:]
        accu_conf[...,sy1,sx1] += conf
        accu_c[...,sy1,sx1] += predconf.view(B,win_height,win_width) * conf

    pred = accu_pred / accu_conf[:, None,:,:]
    c = accu_c / accu_conf
    assert not torch.any(torch.isnan(pred))

    if return_time:
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)/1000.0 # this was in milliseconds

    if do_change_scale:
        # NOTE(review): only pred is brought back to the input resolution; c stays
        # at the upscaled working resolution - confirm callers expect that.
        pred = _resize_stereo_or_flow(pred, original_size)

    if return_time:
        return pred, torch.mean(torch.tensor(tiled_losses)), c, elapsed
    return pred, torch.mean(torch.tensor(tiled_losses)), c
def _overlapping(total, window, overlap=0.5):
    """Yield slices of length ``window`` covering ``[0, total)``.

    The window starts are evenly spaced between 0 and ``total - window`` and
    rounded to integers; their number is chosen so that, before rounding, two
    consecutive starts are at most ``(1 - overlap) * window`` apart.

    Raises:
        AssertionError: if ``total < window`` or ``overlap`` is outside ``[0, 1)``
            (raised when the generator is first advanced).
    """
    assert total >= window and 0 <= overlap < 1, (total, window, overlap)
    max_stride = (1 - overlap) * window
    count = 1 + int(np.ceil((total - window) / max_stride))
    starts = np.linspace(0, total - window, count).round().astype(int)
    for start in starts:
        yield slice(start, start + window)
def _crop(img, sy, sx):
    """Extract the window ``img[:, :, sy, sx]``, zero-padding where it overflows.

    ``sy`` and ``sx`` are slices with explicit ``start``/``stop``. When the
    window lies inside the 4-D image a plain indexing result is returned;
    otherwise the image is constant-padded on the sides where the window
    sticks out and the window is taken from the padded image.
    """
    _, _, height, width = img.shape
    inside_rows = 0 <= sy.start and sy.stop <= height
    inside_cols = 0 <= sx.start and sx.stop <= width
    if inside_rows and inside_cols:
        return img[:, :, sy, sx]

    # amount by which the window exceeds each border
    pad_left = max(0, -sx.start)
    pad_right = max(0, sx.stop - width)
    pad_top = max(0, -sy.start)
    pad_bottom = max(0, sy.stop - height)
    padded = torch.nn.functional.pad(
        img, (pad_left, pad_right, pad_top, pad_bottom), mode='constant')

    # shift the window by the padding added before the origin
    rows = slice(sy.start + pad_top, sy.stop + pad_top)
    cols = slice(sx.start + pad_left, sx.stop + pad_left)
    return padded[:, :, rows, cols]