# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Main function for training one epoch or testing
# --------------------------------------------------------
import math
import sys
from typing import Iterable

import numpy as np
import torch
import torch.nn.functional as F  # moved from mid-file to the top import block (PEP 8)
import torchvision

from utils import misc as misc


def split_prediction_conf(predictions, with_conf=False):
    """Split a network output into (prediction, confidence).

    Args:
        predictions: (B, C, H, W) tensor; when ``with_conf`` is True the last
            channel is a confidence map appended to the actual prediction.
        with_conf: whether a confidence channel is present.

    Returns:
        Tuple ``(predictions, conf)`` where ``conf`` is ``None`` when
        ``with_conf`` is False, else the (B, 1, H, W) confidence slice.
    """
    if not with_conf:
        return predictions, None
    conf = predictions[:, -1:, :, :]
    predictions = predictions[:, :-1, :, :]
    return predictions, conf


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    metrics: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    print_freq=20,
    args=None,
):
    """Train ``model`` for one epoch.

    Args:
        model: network taking (image1, image2) and returning a prediction
            (optionally with a trailing confidence channel).
        criterion: loss module; ``criterion.with_conf`` tells whether it
            expects a confidence argument.
        metrics: metrics module called on (prediction, gt).
        data_loader: iterable yielding (image1, image2, gt, pairname).
        optimizer: optimizer updated through ``loss_scaler``.
        device: device the batches are moved to.
        epoch: current epoch index (used by the lr schedule and logging).
        loss_scaler: AMP-aware scaler callable (loss, optimizer, ...).
        log_writer: optional tensorboard SummaryWriter.
        print_freq: console logging frequency (iterations).
        args: namespace providing accum_iter, amp, img_per_epoch,
            batch_size and tboard_log_step.

    Returns:
        Dict mapping meter name -> global average over the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    if args.img_per_epoch:
        # Cap the epoch at ceil(img_per_epoch / batch_size) iterations.
        iter_per_epoch = args.img_per_epoch // args.batch_size + int(
            args.img_per_epoch % args.batch_size > 0
        )
        assert (
            len(data_loader) >= iter_per_epoch
        ), "Dataset is too small for so many iterations"
        len_data_loader = iter_per_epoch
    else:
        len_data_loader, iter_per_epoch = len(data_loader), None

    for data_iter_step, (image1, image2, gt, pairname) in enumerate(
        metric_logger.log_every(
            data_loader, print_freq, header, max_iter=iter_per_epoch
        )
    ):
        image1 = image1.to(device, non_blocking=True)
        image2 = image2.to(device, non_blocking=True)
        gt = gt.to(device, non_blocking=True)

        # We use a per-iteration (instead of per-epoch) lr scheduler.
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(
                optimizer, data_iter_step / len_data_loader + epoch, args
            )

        with torch.cuda.amp.autocast(enabled=bool(args.amp)):
            prediction = model(image1, image2)
            prediction, conf = split_prediction_conf(prediction, criterion.with_conf)
            batch_metrics = metrics(prediction.detach(), gt)
            loss = (
                criterion(prediction, gt)
                if conf is None
                else criterion(prediction, gt, conf)
            )

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale down so gradients accumulated over accum_iter steps average out.
        loss /= accum_iter
        loss_scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            update_grad=(data_iter_step + 1) % accum_iter == 0,
        )
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        for k, v in batch_metrics.items():
            metric_logger.update(**{k: v.item()})
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        # BUGFIX: the reduced loss was computed twice per iteration; once is
        # enough (misc.all_reduce_mean averages across distributed processes).
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        time_to_log = (data_iter_step + 1) % (
            args.tboard_log_step * accum_iter
        ) == 0 or data_iter_step == len_data_loader - 1
        if log_writer is not None and time_to_log:
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000)
            log_writer.add_scalar("train/loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)
            for k, v in batch_metrics.items():
                log_writer.add_scalar("train/" + k, v.item(), epoch_1000x)

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def validate_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    metrics: torch.nn.Module,
    data_loaders: list[Iterable],
    device: torch.device,
    epoch: int,
    log_writer=None,
    args=None,
):
    """Evaluate ``model`` on each validation loader using tiled inference.

    Args:
        model, criterion, metrics: as in :func:`train_one_epoch`.
        data_loaders: one loader per validation dataset; the dataset's
            ``str()`` is used as a prefix for its metric names.
        device: device the batches are moved to.
        epoch: current epoch index (used for tensorboard x-axis).
        log_writer: optional tensorboard SummaryWriter.
        args: namespace providing tile_conf_mode, crop and val_overlap.

    Returns:
        Dict of averaged metrics; with several datasets, per-metric
        "AVG_<k>" entries average the per-dataset values.
    """
    model.eval()
    metric_loggers = []
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    conf_mode = args.tile_conf_mode
    crop = args.crop

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    results = {}
    dnames = []
    image1, image2, gt, prediction = None, None, None, None
    for didx, data_loader in enumerate(data_loaders):
        dname = str(data_loader.dataset)
        dnames.append(dname)
        metric_loggers.append(misc.MetricLogger(delimiter=" "))

        for data_iter_step, (image1, image2, gt, pairname) in enumerate(
            metric_loggers[didx].log_every(data_loader, print_freq, header)
        ):
            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            gt = gt.to(device, non_blocking=True)
            if dname.startswith("Spring"):
                # Spring gt comes at 2x resolution; average each 2x2 block.
                assert (
                    gt.size(2) == image1.size(2) * 2
                    and gt.size(3) == image1.size(3) * 2
                )
                gt = (
                    gt[:, :, 0::2, 0::2]
                    + gt[:, :, 0::2, 1::2]
                    + gt[:, :, 1::2, 0::2]
                    + gt[:, :, 1::2, 1::2]
                ) / 4.0  # we approximate the gt based on the 2x upsampled ones

            with torch.inference_mode():
                prediction, tiled_loss, c = tiled_pred(
                    model,
                    criterion,
                    image1,
                    image2,
                    gt,
                    conf_mode=conf_mode,
                    overlap=args.val_overlap,
                    crop=crop,
                    with_conf=criterion.with_conf,
                )
                batch_metrics = metrics(prediction.detach(), gt)
                loss = (
                    criterion(prediction.detach(), gt)
                    if not criterion.with_conf
                    else criterion(prediction.detach(), gt, c)
                )
                loss_value = loss.item()
                metric_loggers[didx].update(loss_tiled=tiled_loss.item())
                metric_loggers[didx].update(loss=loss_value)
                for k, v in batch_metrics.items():
                    metric_loggers[didx].update(**{dname + "_" + k: v.item()})

    results = {
        k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()
    }
    if len(dnames) > 1:
        # NOTE(review): this assumes every loader yielded at least one batch,
        # otherwise batch_metrics is undefined here — same as the original.
        for k in batch_metrics.keys():
            results["AVG_" + k] = sum(
                results[dname + "_" + k] for dname in dnames
            ) / len(dnames)

    if log_writer is not None:
        epoch_1000x = int((1 + epoch) * 1000)
        for k, v in results.items():
            log_writer.add_scalar("val/" + k, v, epoch_1000x)

    print("Averaged stats:", results)
    return results


def _resize_img(img, new_size):
    """Bicubic-resize a (B, C, H, W) image tensor to ``new_size`` = (h, w)."""
    return F.interpolate(img, size=new_size, mode="bicubic", align_corners=False)


def _resize_stereo_or_flow(data, new_size):
    """Resize a disparity (1-ch) or flow (2-ch) map and rescale its values.

    Horizontal components are multiplied by the width ratio, and vertical
    components (flow only) by the height ratio, so the map stays metrically
    consistent with the resized image.
    """
    assert data.ndim == 4
    assert data.size(1) in [1, 2]  # disparity or flow
    scale_x = new_size[1] / float(data.size(3))
    out = F.interpolate(data, size=new_size, mode="bicubic", align_corners=False)
    out[:, 0, :, :] *= scale_x
    if out.size(1) == 2:
        scale_y = new_size[0] / float(data.size(2))
        out[:, 1, :, :] *= scale_y
    # BUGFIX: removed leftover debug print(scale_x, new_size, data.shape).
    return out


@torch.no_grad()
def tiled_pred(
    model,
    criterion,
    img1,
    img2,
    gt,
    overlap=0.5,
    bad_crop_thr=0.05,
    downscale=False,
    crop=512,
    ret="loss",
    conf_mode="conf_expsigmoid_10_5",
    with_conf=False,
    return_time=False,
):
    """Run ``model`` on overlapping crops and confidence-average the tiles.

    For each image pair, inference is run on many overlapping
    (crop[0], crop[1]) windows; per-tile predictions are combined by a
    weighted average whose weights come from the predicted confidence.

    Args:
        model: network taking two image crops.
        criterion: optional loss, evaluated per tile when ``gt`` is given.
        img1, img2: (B, 3, H, W) input images.
        gt: optional (B, C, H, W) ground truth; also defines C.
        overlap: fractional overlap between consecutive windows, in [0, 1).
        bad_crop_thr, downscale, ret: unused; kept for interface
            compatibility with existing callers.
        crop: (height, width) of each tile. NOTE(review): the int default
            would fail at ``crop[0]`` — callers are expected to pass a pair.
        conf_mode: "conf_expsigmoid_<b>_<bs>" or "conf_expbeta<b>" weighting.
        with_conf: whether the model output has a confidence channel.
        return_time: also return the GPU inference time in seconds.

    Returns:
        (pred, mean_tiled_loss, c) and, if ``return_time``, the elapsed time.
    """
    if gt is not None:
        B, C, H, W = gt.shape
    else:
        B, _, H, W = img1.shape
        C = model.head.num_channels - int(with_conf)
    win_height, win_width = crop[0], crop[1]

    # Upscale the input so it is at least as large as one crop window.
    do_change_scale = H < win_height or W < win_width
    if do_change_scale:
        # BUGFIX: the height ratio was win_height / W (width in the
        # denominator), which under-scales short-but-wide inputs and breaks
        # the `total >= window` assertion in _overlapping.
        upscale_factor = max(win_width / W, win_height / H)
        original_size = (H, W)
        new_size = (round(H * upscale_factor), round(W * upscale_factor))
        img1 = _resize_img(img1, new_size)
        img2 = _resize_img(img2, new_size)
        # Resize gt just for the computation of tiled losses.
        if gt is not None:
            gt = _resize_stereo_or_flow(gt, new_size)
        H, W = img1.shape[2:4]

    if conf_mode.startswith("conf_expsigmoid_"):  # e.g. conf_expsigmoid_30_10
        beta, betasigmoid = map(float, conf_mode[len("conf_expsigmoid_") :].split("_"))
    elif conf_mode.startswith("conf_expbeta"):  # e.g. conf_expbeta3
        beta = float(conf_mode[len("conf_expbeta") :])
    else:
        raise NotImplementedError(f"conf_mode {conf_mode} is not implemented")

    def crop_generator():
        # Yield aligned (row-slice, col-slice) windows covering the image.
        for sy in _overlapping(H, win_height, overlap):
            for sx in _overlapping(W, win_width, overlap):
                yield sy, sx, sy, sx, True

    # Keep track of the weighted sum of predictions and of the weights.
    accu_pred = img1.new_zeros((B, C, H, W))  # weighted sum of predictions
    accu_conf = img1.new_zeros((B, H, W)) + 1e-16  # sum of weights (eps avoids /0)
    # Weighted sum of confidences; mainly useful for computing some losses.
    accu_c = img1.new_zeros((B, H, W))

    tiled_losses = []

    if return_time:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

    for sy1, sx1, sy2, sx2, aligned in crop_generator():
        # Compute the prediction on this tile.
        pred = model(_crop(img1, sy1, sx1), _crop(img2, sy2, sx2))
        pred, predconf = split_prediction_conf(pred, with_conf=with_conf)

        if gt is not None:
            gtcrop = _crop(gt, sy1, sx1)
        if criterion is not None and gt is not None:
            tiled_losses.append(
                criterion(pred, gtcrop).item()
                if predconf is None
                else criterion(pred, gtcrop, predconf).item()
            )

        if conf_mode.startswith("conf_expsigmoid_"):
            conf = torch.exp(
                -beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5)
            ).view(B, win_height, win_width)
        elif conf_mode.startswith("conf_expbeta"):
            conf = torch.exp(-beta * predconf).view(B, win_height, win_width)
        else:
            raise NotImplementedError

        accu_pred[..., sy1, sx1] += pred * conf[:, None, :, :]
        accu_conf[..., sy1, sx1] += conf
        accu_c[..., sy1, sx1] += predconf.view(B, win_height, win_width) * conf

    pred = accu_pred / accu_conf[:, None, :, :]
    c = accu_c / accu_conf
    assert not torch.any(torch.isnan(pred))

    if return_time:
        end.record()
        torch.cuda.synchronize()
        time = start.elapsed_time(end) / 1000.0  # elapsed_time is in milliseconds

    if do_change_scale:
        pred = _resize_stereo_or_flow(pred, original_size)

    if return_time:
        return pred, torch.mean(torch.tensor(tiled_losses)), c, time
    return pred, torch.mean(torch.tensor(tiled_losses)), c


def _overlapping(total, window, overlap=0.5):
    """Yield slices of length ``window`` covering [0, total) with overlap.

    Offsets are evenly spread so the first slice starts at 0 and the last
    ends exactly at ``total``.
    """
    assert total >= window and 0 <= overlap < 1, (total, window, overlap)
    num_windows = 1 + int(np.ceil((total - window) / ((1 - overlap) * window)))
    offsets = np.linspace(0, total - window, num_windows).round().astype(int)
    yield from (slice(x, x + window) for x in offsets)


def _crop(img, sy, sx):
    """Crop ``img`` with row/col slices, zero-padding out-of-bounds areas."""
    B, THREE, H, W = img.shape
    if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W:
        return img[:, :, sy, sx]
    # Pad just enough on each side, then crop in the shifted coordinates.
    l, r = max(0, -sx.start), max(0, sx.stop - W)
    t, b = max(0, -sy.start), max(0, sy.stop - H)
    img = torch.nn.functional.pad(img, (l, r, t, b), mode="constant")
    return img[:, :, slice(sy.start + t, sy.stop + t), slice(sx.start + l, sx.stop + l)]