from collections import defaultdict import pprint from loguru import logger from pathlib import Path import torch import numpy as np import pytorch_lightning as pl from matplotlib import pyplot as plt from src.models import TopicFM from src.models.utils.supervision import ( compute_supervision_coarse, compute_supervision_fine, ) from src.losses.loss import TopicFMLoss from src.optimizers import build_optimizer, build_scheduler from src.utils.metrics import ( compute_symmetrical_epipolar_errors, compute_pose_errors, aggregate_metrics, ) from src.utils.plotting import make_matching_figures from src.utils.comm import gather, all_gather from src.utils.misc import lower_config, flattenList from src.utils.profiler import PassThroughProfiler class PL_Trainer(pl.LightningModule): def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): """ TODO: - use the new version of PL logging API. """ super().__init__() # Misc self.config = config # full config _config = lower_config(self.config) self.model_cfg = lower_config(_config["model"]) self.profiler = profiler or PassThroughProfiler() self.n_vals_plot = max( config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1 ) # Matcher: TopicFM self.matcher = TopicFM(config=_config["model"]) self.loss = TopicFMLoss(_config) # Pretrained weights if pretrained_ckpt: state_dict = torch.load(pretrained_ckpt, map_location="cpu")["state_dict"] self.matcher.load_state_dict(state_dict, strict=True) logger.info(f"Load '{pretrained_ckpt}' as pretrained checkpoint") # Testing self.dump_dir = dump_dir def configure_optimizers(self): # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` optimizer = build_optimizer(self, self.config) scheduler = build_scheduler(self.config, optimizer) return [optimizer], [scheduler] def optimizer_step( self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs, ): # learning rate warm up warmup_step = self.config.TRAINER.WARMUP_STEP if self.trainer.global_step < warmup_step: if self.config.TRAINER.WARMUP_TYPE == "linear": base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR lr = base_lr + ( self.trainer.global_step / self.config.TRAINER.WARMUP_STEP ) * abs(self.config.TRAINER.TRUE_LR - base_lr) for pg in optimizer.param_groups: pg["lr"] = lr elif self.config.TRAINER.WARMUP_TYPE == "constant": pass else: raise ValueError( f"Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}" ) # update params optimizer.step(closure=optimizer_closure) optimizer.zero_grad() def _trainval_inference(self, batch): with self.profiler.profile("Compute coarse supervision"): compute_supervision_coarse(batch, self.config) with self.profiler.profile("TopicFM"): self.matcher(batch) with self.profiler.profile("Compute fine supervision"): compute_supervision_fine(batch, self.config) with self.profiler.profile("Compute losses"): self.loss(batch) def _compute_metrics(self, batch): with self.profiler.profile("Copmute metrics"): compute_symmetrical_epipolar_errors( batch ) # compute epi_errs for each match compute_pose_errors( batch, self.config ) # compute R_errs, t_errs, pose_errs for each pair rel_pair_names = list(zip(*batch["pair_names"])) bs = batch["image0"].size(0) metrics = { # to filter duplicate pairs caused by DistributedSampler "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)], "epi_errs": [ batch["epi_errs"][batch["m_bids"] == b].cpu().numpy() for b in range(bs) ], "R_errs": batch["R_errs"], "t_errs": batch["t_errs"], "inliers": batch["inliers"], } ret_dict = {"metrics": metrics} return ret_dict, rel_pair_names def training_step(self, batch, batch_idx): self._trainval_inference(batch) # logging if ( self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0 ): # scalars for k, v in batch["loss_scalars"].items(): self.logger.experiment.add_scalar(f"train/{k}", v, self.global_step) # figures if self.config.TRAINER.ENABLE_PLOTTING: compute_symmetrical_epipolar_errors( batch ) # compute epi_errs for each match figures = make_matching_figures( batch, self.config, self.config.TRAINER.PLOT_MODE ) for k, v in figures.items(): self.logger.experiment.add_figure( f"train_match/{k}", v, self.global_step ) return {"loss": batch["loss"]} def training_epoch_end(self, outputs): avg_loss = torch.stack([x["loss"] for x in outputs]).mean() if self.trainer.global_rank == 0: self.logger.experiment.add_scalar( "train/avg_loss_on_epoch", avg_loss, global_step=self.current_epoch ) def validation_step(self, batch, batch_idx): self._trainval_inference(batch) ret_dict, _ = self._compute_metrics(batch) val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1) figures = {self.config.TRAINER.PLOT_MODE: []} if batch_idx % val_plot_interval == 0: figures = make_matching_figures( batch, self.config, mode=self.config.TRAINER.PLOT_MODE ) return { **ret_dict, "loss_scalars": batch["loss_scalars"], "figures": figures, } def validation_epoch_end(self, outputs): # handle multiple validation sets multi_outputs = ( [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs ) multi_val_metrics = defaultdict(list) for valset_idx, outputs in enumerate(multi_outputs): # since pl performs sanity_check at the very begining of the training cur_epoch = self.trainer.current_epoch if ( not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check ): cur_epoch = -1 # 1. loss_scalars: dict of list, on cpu _loss_scalars = [o["loss_scalars"] for o in outputs] loss_scalars = { k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0] } # 2. val metrics: dict of list, numpy _metrics = [o["metrics"] for o in outputs] metrics = { k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0] } # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 val_metrics_4tb = aggregate_metrics( metrics, self.config.TRAINER.EPI_ERR_THR ) for thr in [5, 10, 20]: multi_val_metrics[f"auc@{thr}"].append(val_metrics_4tb[f"auc@{thr}"]) # 3. figures _figures = [o["figures"] for o in outputs] figures = { k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0] } # tensorboard records only on rank 0 if self.trainer.global_rank == 0: for k, v in loss_scalars.items(): mean_v = torch.stack(v).mean() self.logger.experiment.add_scalar( f"val_{valset_idx}/avg_{k}", mean_v, global_step=cur_epoch ) for k, v in val_metrics_4tb.items(): self.logger.experiment.add_scalar( f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch ) for k, v in figures.items(): if self.trainer.global_rank == 0: for plot_idx, fig in enumerate(v): self.logger.experiment.add_figure( f"val_match_{valset_idx}/{k}/pair-{plot_idx}", fig, cur_epoch, close=True, ) plt.close("all") for thr in [5, 10, 20]: # log on all ranks for ModelCheckpoint callback to work properly self.log( f"auc@{thr}", torch.tensor(np.mean(multi_val_metrics[f"auc@{thr}"])) ) # ckpt monitors on this def test_step(self, batch, batch_idx): with self.profiler.profile("TopicFM"): self.matcher(batch) ret_dict, rel_pair_names = self._compute_metrics(batch) with self.profiler.profile("dump_results"): if self.dump_dir is not None: # dump results for further analysis keys_to_save = {"mkpts0_f", "mkpts1_f", "mconf", "epi_errs"} pair_names = list(zip(*batch["pair_names"])) bs = batch["image0"].shape[0] dumps = [] for b_id in range(bs): item = {} mask = batch["m_bids"] == b_id item["pair_names"] = pair_names[b_id] item["identifier"] = "#".join(rel_pair_names[b_id]) for key in keys_to_save: item[key] = batch[key][mask].cpu().numpy() for key in ["R_errs", "t_errs", "inliers"]: item[key] = batch[key][b_id] dumps.append(item) ret_dict["dumps"] = dumps return ret_dict def test_epoch_end(self, outputs): # metrics: dict of list, numpy _metrics = [o["metrics"] for o in outputs] metrics = { k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0] } # [{key: [{...}, *#bs]}, *#batch] if self.dump_dir is not None: Path(self.dump_dir).mkdir(parents=True, exist_ok=True) _dumps = flattenList([o["dumps"] for o in outputs]) # [{...}, #bs*#batch] dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch] logger.info( f"Prediction and evaluation results will be saved to: {self.dump_dir}" ) if self.trainer.global_rank == 0: print(self.profiler.summary()) val_metrics_4tb = aggregate_metrics( metrics, self.config.TRAINER.EPI_ERR_THR ) logger.info("\n" + pprint.pformat(val_metrics_4tb)) if self.dump_dir is not None: np.save(Path(self.dump_dir) / "TopicFM_pred_eval", dumps)