# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# --------------------------------------------------------
# Main function for training one epoch or testing
# --------------------------------------------------------

import math
import sys
from typing import Iterable

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

from utils import misc as misc


def split_prediction_conf(predictions, with_conf=False):
    """Split the raw network output into ``(prediction, confidence)``.

    When ``with_conf`` is True, the last channel (dim 1) is the confidence
    map and the remaining channels are the prediction. Otherwise the output
    is returned unchanged together with ``None``.
    """
    if not with_conf:
        return predictions, None
    conf = predictions[:, -1:, :, :]
    predictions = predictions[:, :-1, :, :]
    return predictions, conf


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, metrics: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None, print_freq=20,
                    args=None):
    """Train ``model`` for one epoch.

    If ``args.img_per_epoch`` is set, the epoch lasts
    ``ceil(img_per_epoch / batch_size)`` iterations instead of the full loader.
    The learning rate is adjusted per iteration. Gradients are accumulated
    over ``args.accum_iter`` steps, and autocast is enabled when ``args.amp``
    is set. The process exits (``sys.exit(1)``) if the loss becomes non-finite.

    Returns:
        dict mapping each meter name of the metric logger to its global average.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    if args.img_per_epoch:
        # ceil(img_per_epoch / batch_size) iterations
        iter_per_epoch = args.img_per_epoch // args.batch_size + int(args.img_per_epoch % args.batch_size > 0)
        assert len(data_loader) >= iter_per_epoch, 'Dataset is too small for so many iterations'
        len_data_loader = iter_per_epoch
    else:
        len_data_loader, iter_per_epoch = len(data_loader), None

    for data_iter_step, (image1, image2, gt, pairname) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header, max_iter=iter_per_epoch)):

        image1 = image1.to(device, non_blocking=True)
        image2 = image2.to(device, non_blocking=True)
        gt = gt.to(device, non_blocking=True)

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(optimizer, data_iter_step / len_data_loader + epoch, args)

        with torch.cuda.amp.autocast(enabled=bool(args.amp)):
            prediction = model(image1, image2)
            prediction, conf = split_prediction_conf(prediction, criterion.with_conf)

            batch_metrics = metrics(prediction.detach(), gt)
            loss = criterion(prediction, gt) if conf is None else criterion(prediction, gt, conf)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss /= accum_iter
        update_grad = (data_iter_step + 1) % accum_iter == 0
        loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=update_grad)
        if update_grad:
            optimizer.zero_grad()

        # torch.cuda.synchronize() fails when CUDA is unavailable, so only call it when it can work
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        for k, v in batch_metrics.items():
            metric_logger.update(**{k: v.item()})
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        time_to_log = ((data_iter_step + 1) % (args.tboard_log_step * accum_iter) == 0
                       or data_iter_step == len_data_loader - 1)
        if time_to_log:
            # time_to_log depends only on the step index, so all processes reach this call on the
            # same steps. NOTE(review): assumes all_reduce_mean is a collective that every process
            # must join, and that every rank iterates the same number of steps.
            loss_value_reduce = misc.all_reduce_mean(loss_value)
            if log_writer is not None:
                # We use epoch_1000x as the x-axis in tensorboard.
                # This calibrates different curves when batch size changes.
                epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000)
                log_writer.add_scalar('train/loss', loss_value_reduce, epoch_1000x)
                log_writer.add_scalar('lr', lr, epoch_1000x)
                for k, v in batch_metrics.items():
                    log_writer.add_scalar('train/' + k, v.item(), epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def validate_one_epoch(model: torch.nn.Module,
                       criterion: torch.nn.Module,
                       metrics: torch.nn.Module,
                       data_loaders: list[Iterable],
                       device: torch.device,
                       epoch: int,
                       log_writer=None,
                       args=None):
    """Evaluate ``model`` with tiled inference on every loader in ``data_loaders``.

    Metrics are logged as ``<dataset>_<metric>``, where the dataset name is
    ``str(loader.dataset)``. With several datasets, ``AVG_<metric>`` is the
    unweighted mean over the datasets that report that metric.

    Returns:
        dict of averaged statistics. It is also written to ``log_writer``
        under ``val/`` when a writer is given.
    """
    model.eval()
    metric_loggers = []
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    conf_mode = args.tile_conf_mode
    crop = args.crop

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    dnames = []
    metric_names = []  # metric keys seen so far, in first-seen order (used for AVG_ entries)
    for didx, data_loader in enumerate(data_loaders):
        dname = str(data_loader.dataset)
        dnames.append(dname)
        metric_loggers.append(misc.MetricLogger(delimiter=" "))

        for data_iter_step, (image1, image2, gt, pairname) in enumerate(
                metric_loggers[didx].log_every(data_loader, print_freq, header)):
            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            gt = gt.to(device, non_blocking=True)
            if dname.startswith('Spring'):
                assert gt.size(2) == image1.size(2) * 2 and gt.size(3) == image1.size(3) * 2
                # we approximate the gt based on the 2x upsampled ones
                gt = (gt[:, :, 0::2, 0::2] + gt[:, :, 0::2, 1::2]
                      + gt[:, :, 1::2, 0::2] + gt[:, :, 1::2, 1::2]) / 4.0

            with torch.inference_mode():
                prediction, tiled_loss, c = tiled_pred(model, criterion, image1, image2, gt,
                                                       conf_mode=conf_mode, overlap=args.val_overlap,
                                                       crop=crop, with_conf=criterion.with_conf)
                batch_metrics = metrics(prediction.detach(), gt)
                if not criterion.with_conf:
                    loss = criterion(prediction.detach(), gt)
                else:
                    loss = criterion(prediction.detach(), gt, c)
                loss_value = loss.item()
                metric_loggers[didx].update(loss_tiled=tiled_loss.item())
                metric_loggers[didx].update(**{f'loss': loss_value})
                for k, v in batch_metrics.items():
                    metric_loggers[didx].update(**{dname + '_' + k: v.item()})
                    if k not in metric_names:
                        metric_names.append(k)

    # NOTE: 'loss' and 'loss_tiled' are not prefixed with the dataset name, so with several
    # loaders the value from the last logger overwrites the others in this dict.
    results = {k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()}

    if len(dnames) > 1:
        for k in metric_names:
            # only datasets that actually produced this metric (an empty loader produces none)
            values = [results[dname + '_' + k] for dname in dnames if dname + '_' + k in results]
            results['AVG_' + k] = sum(values) / len(values)

    if log_writer is not None:
        epoch_1000x = int((1 + epoch) * 1000)
        for k, v in results.items():
            log_writer.add_scalar('val/' + k, v, epoch_1000x)

    print("Averaged stats:", results)
    return results


def _resize_img(img, new_size):
    """Resize a 4-D image batch to ``new_size`` = (height, width) with bicubic interpolation."""
    return F.interpolate(img, size=new_size, mode='bicubic', align_corners=False)


def _resize_stereo_or_flow(data, new_size):
    """Resize a 1-channel (stereo) or 2-channel (flow) map and rescale its values.

    Channel 0 is multiplied by the width ratio. For 2-channel data, channel 1
    is multiplied by the height ratio, so that displacements stay consistent
    with the new pixel grid.
    """
    assert data.ndim == 4
    assert data.size(1) in [1, 2]
    scale_x = new_size[1] / float(data.size(3))
    out = F.interpolate(data, size=new_size, mode='bicubic', align_corners=False)
    out[:, 0, :, :] *= scale_x
    if out.size(1) == 2:
        scale_y = new_size[0] / float(data.size(2))
        out[:, 1, :, :] *= scale_y
    return out


@torch.no_grad()
def tiled_pred(model, criterion, img1, img2, gt,
               overlap=0.5, bad_crop_thr=0.05,
               downscale=False, crop=512, ret='loss',
               conf_mode='conf_expsigmoid_10_5', with_conf=False,
               return_time=False):
    """Predict on a full image pair by weighted-averaging overlapping crop predictions.

    Args:
        crop: window size, either an int (square) or a (height, width) pair.
        conf_mode: how per-pixel confidences become averaging weights,
            ``'conf_expsigmoid_<beta>_<betasigmoid>'`` or ``'conf_expbeta<beta>'``.
            Uniform weights are used when the model outputs no confidence.
        bad_crop_thr, downscale, ret: accepted for interface compatibility; unused.

    Returns:
        ``(pred, mean_tiled_loss, conf)``, plus the elapsed time in seconds
        (measured with CUDA events) when ``return_time`` is True.
    """
    # NOTE(review): in the copy of this file that was reviewed, the middle of this function
    # (from the rescale condition to the start of _overlapping) had been lost. It was restored
    # here; double-check the confidence-weighting formulas against version control.

    # for each image, we are going to run inference on many overlapping patches
    # then, all predictions will be weighted-averaged
    if gt is not None:
        B, C, H, W = gt.shape
    else:
        B, _, H, W = img1.shape
        C = model.head.num_channels - int(with_conf)
    if isinstance(crop, int):
        crop = (crop, crop)  # e.g. the default crop=512 means a 512x512 window
    win_height, win_width = crop[0], crop[1]

    # upscale to be larger than the crop
    do_change_scale = H < win_height or W < win_width
    if do_change_scale:
        upscale_factor = max(win_width / W, win_height / H)
        original_size = (H, W)
        new_size = (round(H * upscale_factor), round(W * upscale_factor))
        img1 = _resize_img(img1, new_size)
        img2 = _resize_img(img2, new_size)
        # resize gt just for the computation of tiled losses
        if gt is not None:
            gt = _resize_stereo_or_flow(gt, new_size)
        H, W = img1.shape[2:4]

    if conf_mode.startswith('conf_expsigmoid_'):  # e.g. conf_expsigmoid_10_5
        beta, betasigmoid = map(float, conf_mode[len('conf_expsigmoid_'):].split('_'))
    elif conf_mode.startswith('conf_expbeta'):  # e.g. conf_expbeta3
        beta = float(conf_mode[len('conf_expbeta'):])
    else:
        raise NotImplementedError(f"conf_mode {conf_mode} is not implemented")

    def _tile_weight(predconf):
        # per-pixel averaging weight of one tile, shaped (B, win_height, win_width)
        if predconf is None:
            return img1.new_ones((B, win_height, win_width))
        if conf_mode.startswith('conf_expsigmoid_'):
            weight = torch.exp(-beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5))
        else:
            weight = torch.exp(-beta * predconf)
        return weight.view(B, win_height, win_width)

    # keep track of weighted sum of prediction*weights and weights
    accu_pred = img1.new_zeros((B, C, H, W))  # weighted sum of predictions
    accu_conf = img1.new_zeros((B, H, W)) + 1e-16  # sum of weights; epsilon avoids 0/0
    accu_c = img1.new_zeros((B, H, W))  # weighted sum of raw confidences

    tiled_losses = []

    if return_time:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

    for sy in _overlapping(H, win_height, overlap):
        for sx in _overlapping(W, win_width, overlap):
            pred = model(_crop(img1, sy, sx), _crop(img2, sy, sx))
            pred, predconf = split_prediction_conf(pred, with_conf=with_conf)

            if criterion is not None and gt is not None:
                gtcrop = _crop(gt, sy, sx)
                if predconf is None:
                    tile_loss = criterion(pred, gtcrop)
                else:
                    tile_loss = criterion(pred, gtcrop, predconf)
                tiled_losses.append(tile_loss.item())

            conf = _tile_weight(predconf)
            accu_pred[..., sy, sx] += pred * conf[:, None, :, :]
            accu_conf[..., sy, sx] += conf
            if predconf is not None:
                accu_c[..., sy, sx] += predconf.view(B, win_height, win_width) * conf

    pred = accu_pred / accu_conf[:, None, :, :]
    c = accu_c / accu_conf
    assert not torch.any(torch.isnan(pred))

    if return_time:
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end) / 1000.0  # elapsed_time() is in milliseconds

    if do_change_scale:
        pred = _resize_stereo_or_flow(pred, original_size)
        # keep the confidence map at the same resolution as the returned prediction
        c = _resize_img(c[:, None, :, :], original_size)[:, 0, :, :]

    tiled_loss = torch.mean(torch.tensor(tiled_losses))
    if return_time:
        return pred, tiled_loss, c, elapsed
    return pred, tiled_loss, c


def _overlapping(total, window, overlap=0.5):
    """Yield ``slice`` objects of length ``window`` that cover ``[0, total)``.

    Offsets are evenly spaced from 0 to ``total - window``. The number of
    windows is chosen so consecutive windows overlap by roughly at least
    ``overlap * window`` pixels (up to rounding of the offsets).
    """
    assert total >= window and 0 <= overlap < 1, (total, window, overlap)
    num_windows = 1 + int(np.ceil((total - window) / ((1 - overlap) * window)))
    offsets = np.linspace(0, total - window, num_windows).round().astype(int)
    yield from (slice(int(x), int(x) + window) for x in offsets)


def _crop(img, sy, sx):
    """Return ``img[:, :, sy, sx]``, zero-padding where the slices fall outside the image."""
    B, THREE, H, W = img.shape
    if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W:
        return img[:, :, sy, sx]
    l, r = max(0, -sx.start), max(0, sx.stop - W)
    t, b = max(0, -sy.start), max(0, sy.stop - H)
    img = torch.nn.functional.pad(img, (l, r, t, b), mode='constant')
    return img[:, :, slice(sy.start + t, sy.stop + t), slice(sx.start + l, sx.stop + l)]