import os
import logging
import math
import random
import sys
import traceback
from contextlib import nullcontext

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from tensorboardX import SummaryWriter

from mono.model.monodepth_model import get_configured_monodepth_model
from mono.model.criterion import build_criterions
from mono.datasets.distributed_sampler import build_dataset_n_sampler_with_cfg, build_data_array
from mono.utils.comm import TrainingStats, reduce_dict, main_process, get_rank
from mono.utils.avg_meter import MetricAverageMeter
from mono.utils.running import build_lr_schedule_with_cfg, build_optimizer_with_cfg, load_ckpt, save_ckpt
from mono.utils.visualization import (save_val_imgs, visual_train_data, create_html,
                                      save_normal_val_imgs, create_dir_for_validate_meta)
from mono.utils.logger import setup_logger

from .misc import NativeScalerWithGradNormCount, is_bf16_supported


def to_cuda(data: dict):
    """Move all tensors (and lists of tensors) in a batch dict to the GPU."""
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            data[k] = v.cuda(non_blocking=True)
        # note: lists with fewer than two elements are left untouched
        if isinstance(v, list) and len(v) > 1 and isinstance(v[0], torch.Tensor):
            for i, l_i in enumerate(v):
                data[k][i] = l_i.cuda(non_blocking=True)
    return data


def do_train(local_rank: int, cfg: dict):
    logger = setup_logger(cfg.log_file)

    # build criterions
    criterions = build_criterions(cfg)

    # build model
    model = get_configured_monodepth_model(cfg, criterions)

    # log model state_dict
    if main_process():
        logger.info(model.state_dict().keys())

    # build datasets
    train_dataset, train_sampler = build_dataset_n_sampler_with_cfg(cfg, 'train')
    if 'multi_dataset_eval' in cfg.evaluation and cfg.evaluation.multi_dataset_eval:
        val_dataset = build_data_array(cfg, 'val')
    else:
        val_dataset, val_sampler = build_dataset_n_sampler_with_cfg(cfg, 'val')

    # build data loaders; seed the sampling generator per rank for reproducibility
    g = torch.Generator()
    g.manual_seed(cfg.seed + cfg.dist_params.global_rank)
    train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=cfg.batchsize_per_gpu,
                                                   num_workers=cfg.thread_per_gpu,
                                                   sampler=train_sampler,
                                                   drop_last=True,
                                                   pin_memory=True,
                                                   generator=g)
    if isinstance(val_dataset, list):
        # one loader per sub-dataset in each evaluation group
        val_dataloader = [torch.utils.data.DataLoader(
                              dataset=val_set,
                              batch_size=1,
                              num_workers=0,
                              sampler=torch.utils.data.distributed.DistributedSampler(val_set, shuffle=False),
                              drop_last=True,
                              pin_memory=True)
                          for val_group in val_dataset for val_set in val_group]
    else:
        val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                     batch_size=1,
                                                     num_workers=0,
                                                     sampler=val_sampler,
                                                     drop_last=True,
                                                     pin_memory=True)

    # build schedule
    lr_scheduler = build_lr_schedule_with_cfg(cfg)
    optimizer = build_optimizer_with_cfg(cfg, model)

    # config distributed training
    if cfg.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[local_rank],
                                                          output_device=local_rank,
                                                          find_unused_parameters=False)
    else:
        model = torch.nn.DataParallel(model.cuda())

    # mixed precision: the AMP runners below use bf16 autocast directly,
    # so no gradient scaler is created here
    loss_scaler = None

    # load ckpt
    if cfg.load_from and cfg.resume_from is None:
        model, _, _, loss_scaler = load_ckpt(cfg.load_from, model, optimizer=None,
                                             scheduler=None, strict_match=False,
                                             loss_scaler=loss_scaler)
    elif cfg.resume_from:
        model, optimizer, lr_scheduler, loss_scaler = load_ckpt(
            cfg.resume_from, model, optimizer=optimizer, scheduler=lr_scheduler,
            strict_match=False, loss_scaler=loss_scaler)

    if cfg.runner.type == 'IterBasedRunner':
        train_by_iters(cfg, model, optimizer, lr_scheduler, train_dataloader, val_dataloader)
    elif cfg.runner.type == 'IterBasedRunner_MultiSize':
        train_by_iters_multisize(cfg, model, optimizer, lr_scheduler, train_dataloader, val_dataloader)
    elif cfg.runner.type == 'IterBasedRunner_AMP':
        train_by_iters_amp(cfg=cfg,
                           model=model,
                           optimizer=optimizer,
                           lr_scheduler=lr_scheduler,
                           train_dataloader=train_dataloader,
                           val_dataloader=val_dataloader,
                           loss_scaler=loss_scaler)
    elif cfg.runner.type == 'IterBasedRunner_AMP_MultiSize':
        train_by_iters_amp_multisize(cfg=cfg,
                                     model=model,
                                     optimizer=optimizer,
                                     lr_scheduler=lr_scheduler,
                                     train_dataloader=train_dataloader,
                                     val_dataloader=val_dataloader,
                                     loss_scaler=loss_scaler)
    elif cfg.runner.type == 'EpochBasedRunner':
        raise RuntimeError('EpochBasedRunner is not supported currently.')
    else:
        raise RuntimeError(f'Runner type {cfg.runner.type} is not supported currently.')
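# ---------------------------------------------------------------------------
# Launch sketch (illustrative only, not part of this module's API): `do_train`
# assumes the default process group is already initialized when
# cfg.distributed is set. A per-process entry point could look like the
# following, where `world_size` is an assumed field of cfg.dist_params:
#
#   def main_worker(local_rank, cfg):
#       torch.cuda.set_device(local_rank)
#       dist.init_process_group(backend='nccl',
#                               rank=cfg.dist_params.global_rank,
#                               world_size=cfg.dist_params.world_size)
#       do_train(local_rank, cfg)
#
#   torch.multiprocessing.spawn(main_worker, nprocs=n_gpus, args=(cfg,))
# ---------------------------------------------------------------------------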
def train_by_iters(cfg, model, optimizer, lr_scheduler, train_dataloader, val_dataloader):
    """
    Do the training by iterations.
    """
    logger = logging.getLogger()
    tb_logger = None
    if cfg.use_tensorboard and main_process():
        tb_logger = SummaryWriter(cfg.tensorboard_dir)
    if main_process():
        training_stats = TrainingStats(log_period=cfg.log_interval, tensorboard_logger=tb_logger)

    lr_scheduler.before_run(optimizer)

    # set training steps
    max_iters = cfg.runner.max_iters
    start_iter = lr_scheduler._step_count
    save_interval = cfg.checkpoint_config.interval
    eval_interval = cfg.evaluation.interval
    epoch = 0

    logger.info('Create iterator.')
    dataloader_iterator = iter(train_dataloader)
    val_err = {}

    logger.info('Start training.')
    try:
        # use a while loop so every process keeps the same step count and
        # no rank gets stuck at the eval barrier
        step = start_iter
        while step < max_iters:
            if main_process():
                training_stats.IterTic()

            # get the data batch
            try:
                data = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(train_dataloader)
                data = next(dataloader_iterator)
            except Exception as e:
                logger.info('Error when loading training data: %s', e)
                continue
            data = to_cuda(data)

            # set random crop size
            # if step % 10 == 0:
            #     set_random_crop_size_for_iter(train_dataloader, step, size_sample_list[step])

            # check training data (optional debug visualization):
            # for i in range(data['target'].shape[0]):
            #     replace = any(d in data['dataset'][i] for d in ('DDAD', 'Lyft', 'DSEC', 'Argovers2'))
            #     visual_train_data(data['target'][i, ...], data['input'][i, ...],
            #                       data['filename'][i], cfg.work_dir, replace=replace)

            # forward
            pred_depth, losses_dict, conf = model(data)

            optimizer.zero_grad()
            losses_dict['total_loss'].backward()
            # if step > 100 and step % 10 == 0:
            #     for param in model.parameters():
            #         print(param.grad.max(), torch.norm(param.grad))
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
            optimizer.step()

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_dict(losses_dict)
            lr_scheduler.after_train_iter(optimizer)
            if main_process():
                training_stats.update_iter_stats(loss_dict_reduced)
                training_stats.IterToc()
                training_stats.log_iter_stats(step, optimizer, max_iters, val_err)

            # validate the model
            if cfg.evaluation.online_eval and \
               (step + 1) % eval_interval == 0 and \
               val_dataloader is not None:
                if isinstance(val_dataloader, list):
                    val_err = validate_multiple_dataset(cfg, step + 1, model, val_dataloader, tb_logger)
                else:
                    val_err = validate(cfg, step + 1, model, val_dataloader, tb_logger)
                if main_process():
                    training_stats.tb_log_stats(val_err, step)

            # save checkpoint
            if main_process():
                if ((step + 1) % save_interval == 0) or ((step + 1) == max_iters):
                    save_ckpt(cfg, model, optimizer, lr_scheduler, step + 1, epoch)

            step += 1
    except (RuntimeError, KeyboardInterrupt):
        stack_trace = traceback.format_exc()
        print(stack_trace)
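# The StopIteration handling above turns a finite DataLoader into an endless
# stream so iteration-based training can run past epoch boundaries. The same
# pattern as a standalone generator (a sketch, not used by the runners here):
#
#   def endless(dataloader):
#       while True:
#           for batch in dataloader:
#               yield batch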
def train_by_iters_amp(cfg, model, optimizer, lr_scheduler, train_dataloader, val_dataloader, loss_scaler):
    """
    Do the training by iterations. Mixed precision is employed.
    """
    # set up logger
    tb_logger = None
    if cfg.use_tensorboard and main_process():
        tb_logger = SummaryWriter(cfg.tensorboard_dir)
    logger = logging.getLogger()

    # training status
    if main_process():
        training_stats = TrainingStats(log_period=cfg.log_interval, tensorboard_logger=tb_logger)

    # learning schedule
    lr_scheduler.before_run(optimizer)

    # set training steps
    max_iters = cfg.runner.max_iters
    start_iter = lr_scheduler._step_count
    save_interval = cfg.checkpoint_config.interval
    eval_interval = cfg.evaluation.interval
    epoch = 0

    # If iterator creation is too slow, try lowering num_workers; see
    # https://discuss.pytorch.org/t/define-iterator-on-dataloader-is-very-slow/52238
    logger.info('Create iterator.')
    dataloader_iterator = iter(train_dataloader)
    val_err = {}

    logger.info('Start training.')
    try:
        acc_batch = cfg.acc_batch   # gradient-accumulation factor
    except (AttributeError, KeyError):
        acc_batch = 1

    try:
        # use a while loop so every process keeps the same step count and no
        # rank gets stuck at the eval barrier; `step` counts micro-batches,
        # so effective iterations are step // acc_batch
        step = start_iter * acc_batch
        # while step < max_iters:
        while True:
            if main_process():
                training_stats.IterTic()

            # get the data batch
            try:
                data = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(train_dataloader)
                data = next(dataloader_iterator)
            except Exception as e:
                logger.info('Error when loading training data: %s', e)
                continue
            data = to_cuda(data)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                pred_depth, losses_dict, conf = model(data)

            # scale the loss so accumulated gradients average over micro-batches
            total_loss = losses_dict['total_loss'] / acc_batch
            if not math.isfinite(total_loss):
                logger.info('Loss is {}, skipping this batch'.format(total_loss))
                continue

            is_boundary = (step + 1 - start_iter) % acc_batch == 0
            effective_iter = (step + acc_batch) // acc_batch

            # optimize: gradients accumulate across acc_batch micro-batches,
            # then the optimizer steps and the gradients are cleared
            if loss_scaler is None:
                total_loss.backward()
                if is_boundary:
                    try:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.5, error_if_nonfinite=True)
                        optimizer.step()
                    except RuntimeError:
                        print('NaN gradient, skipping optimizer.step() for this round...')
                    optimizer.zero_grad()
            else:
                # the scaler runs backward itself and only steps the optimizer
                # when update_grad is True
                loss_scaler(total_loss, optimizer, clip_grad=5,
                            parameters=model.parameters(), update_grad=is_boundary)
                if is_boundary:
                    optimizer.zero_grad()

            if is_boundary:
                # reduce losses over all GPUs for logging purposes
                loss_dict_reduced = reduce_dict(losses_dict)
                lr_scheduler.after_train_iter(optimizer)
                if main_process():
                    training_stats.update_iter_stats(loss_dict_reduced)
                    training_stats.IterToc()
                    training_stats.log_iter_stats(step // acc_batch, optimizer, max_iters, val_err)

            # validate the model
            if cfg.evaluation.online_eval and is_boundary and \
               effective_iter % eval_interval == 0 and \
               val_dataloader is not None:
                if isinstance(val_dataloader, list):
                    val_err = validate_multiple_dataset(cfg, effective_iter, model, val_dataloader, tb_logger)
                else:
                    val_err = validate(cfg, effective_iter, model, val_dataloader, tb_logger)
                if main_process():
                    training_stats.tb_log_stats(val_err, step)

            # save checkpoint
            if main_process() and is_boundary:
                if (effective_iter % save_interval == 0) or (effective_iter == max_iters):
                    save_ckpt(cfg, model, optimizer, lr_scheduler, effective_iter, epoch,
                              loss_scaler=loss_scaler)

            step += 1
    except (RuntimeError, KeyboardInterrupt):
        stack_trace = traceback.format_exc()
        print(stack_trace)
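# Gradient-accumulation sketch (illustrative): stripped of AMP, logging and
# distributed logic, the update schedule in train_by_iters_amp reduces to:
#
#   for i, batch in enumerate(loader):
#       loss = compute_loss(model, batch) / acc_batch  # average over micro-batches
#       loss.backward()                                # grads accumulate in .grad
#       if (i + 1) % acc_batch == 0:
#           torch.nn.utils.clip_grad_norm_(model.parameters(), 2.5)
#           optimizer.step()
#           optimizer.zero_grad()
#
# `compute_loss` is a placeholder for the model forward under autocast.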
def validate_multiple_dataset(cfg, iter, model, val_dataloaders, tb_logger):
    val_errs = {}
    for val_dataloader in val_dataloaders:
        val_err = validate(cfg, iter, model, val_dataloader, tb_logger)
        val_errs.update(val_err)
    # mean of each metric over all datasets
    mean_val_err = {}
    for k, v in val_errs.items():
        metric = 'AllData_eval/' + k.split('/')[-1]
        if metric not in mean_val_err.keys():
            mean_val_err[metric] = 0
        mean_val_err[metric] += v / len(val_dataloaders)
    val_errs.update(mean_val_err)
    return val_errs


def validate(cfg, iter, model, val_dataloader, tb_logger):
    """
    Validate the model on a single dataset.
    """
    model.eval()
    dist.barrier()
    logger = logging.getLogger()

    # prepare dir for visualization data
    save_val_meta_data_dir = create_dir_for_validate_meta(cfg.work_dir, iter)

    dataset_name = val_dataloader.dataset.data_name
    save_point = max(int(len(val_dataloader) / 5), 1)   # save ~5 visualizations per run

    # depth metric meter
    dam = MetricAverageMeter(cfg.evaluation.metrics)

    for i, data in enumerate(val_dataloader):
        if i % 10 == 0:
            logger.info(f'Validation step on {dataset_name}: {i}')

        data = to_cuda(data)
        output = model.module.inference(data)
        pred_depth = output['prediction'].squeeze()
        gt_depth = data['target'].cuda(non_blocking=True).squeeze()

        # crop away the (top, bottom, left, right) padding added in preprocessing
        pad = data['pad'].squeeze()
        H, W = pred_depth.shape
        pred_depth = pred_depth[pad[0]:H - pad[1], pad[2]:W - pad[3]]
        gt_depth = gt_depth[pad[0]:H - pad[1], pad[2]:W - pad[3]]
        rgb = data['input'][0, :, pad[0]:H - pad[1], pad[2]:W - pad[3]]

        mask = gt_depth > 0
        dam.update_metrics_gpu(pred_depth, gt_depth, mask, cfg.distributed)

        # save evaluation results
        if i % save_point == 0 and main_process():
            save_val_imgs(iter, pred_depth, gt_depth, rgb,
                          dataset_name + '_' + data['filename'][0],
                          save_val_meta_data_dir,
                          tb_logger=tb_logger)

        # surface normal
        if 'normal_out_list' in output.keys():
            normal_out_list = output['normal_out_list']
            pred_normal = normal_out_list[-1][:, :3, :, :]   # (B, 3, H, W)
            gt_normal = data['normal'].cuda(non_blocking=True)

            H, W = pred_normal.shape[2:]
            pred_normal = pred_normal[:, :, pad[0]:H - pad[1], pad[2]:W - pad[3]]
            gt_normal = gt_normal[:, :, pad[0]:H - pad[1], pad[2]:W - pad[3]]

            # valid where the ground-truth normal is not the all-zero vector
            gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
            dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)

            # save validation normals
            if i % save_point == 0 and main_process():
                save_normal_val_imgs(iter, pred_normal, gt_normal, rgb,
                                     dataset_name + '_normal_' + data['filename'][0],
                                     save_val_meta_data_dir,
                                     tb_logger=tb_logger)

    # path pattern of the merged rgb/pred/gt visualizations, kept for the
    # optional html overview below
    merged_rgb_pred_gt = os.path.join(save_val_meta_data_dir, '*_merge.jpg')
    name2path = dict(merg=merged_rgb_pred_gt)
    # if main_process():
    #     create_html(name2path, save_path=save_val_meta_data_dir + '.html', size=(256*3, 512))

    # get validation error
    eval_error = dam.get_metrics()
    eval_error = {f'{dataset_name}_eval/{k}': v for k, v in eval_error.items()}

    model.train()
    if 'exclude' in cfg.evaluation and dataset_name in cfg.evaluation.exclude:
        return {}
    return eval_error
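# Padding note: data['pad'] stores the (top, bottom, left, right) pixels added
# when resizing inputs to the network resolution, so a padded H x W map is
# restored to the raw image region with
#
#   unpadded = padded[top:H - bottom, left:W - right]
#
# e.g. pad = (4, 4, 0, 0) on a 480 x 640 prediction yields a 472 x 640 crop.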
def set_random_crop_size_for_iter(dataloader: torch.utils.data.DataLoader, iter: int, size_pool=None):
    """Sample an (H, W) crop size for this iteration and push it to every sub-dataset."""
    if size_pool is None:
        size_pool = [
            # [504, 504],
            [560, 1008], [840, 1512], [1120, 2016],
            [560, 1008], [840, 1512], [1120, 2016],
            # [480, 768], [480, 960],
            # [480, 992], [480, 1024],
            # [480, 1120],
            # [480, 1280],
            # [480, 1312],
            # [512, 512], [512, 640],
            # [512, 960],
            # [512, 992],
            # [512, 1024], [512, 1120],
            # [512, 1216],
            # [512, 1280],
            # [576, 640], [576, 960],
            # [576, 992],
            # [576, 1024],
            # [608, 608], [608, 640],
            # [608, 960], [608, 1024],
        ]
    # seed with the iteration index so every process samples the same size
    random.seed(iter)
    crop_size = random.choice(size_pool)
    # alternative (deterministic cycling): idx = (iter // 10) % len(size_pool)

    # set crop size for each sub-dataset of every dataset group
    datasets_groups = len(dataloader.dataset.datasets)
    for i in range(datasets_groups):
        for j in range(len(dataloader.dataset.datasets[i].datasets)):
            dataloader.dataset.datasets[i].datasets[j].set_random_crop_size(crop_size)
    return crop_size
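# Usage sketch for set_random_crop_size_for_iter (illustrative): the
# commented-out call in train_by_iters suggests resampling periodically, e.g.
#
#   if step % 10 == 0:
#       crop_size = set_random_crop_size_for_iter(train_dataloader, step)
#
# Seeding `random` with the iteration index keeps the sampled size identical
# across all distributed workers without extra communication.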