from collections import defaultdict
import pprint
from pathlib import Path

import torch
import numpy as np
import pytorch_lightning as pl
from loguru import logger
from matplotlib import pyplot as plt

from src.models import TopicFM
from src.models.utils.supervision import (
    compute_supervision_coarse,
    compute_supervision_fine,
)
from src.losses.loss import TopicFMLoss
from src.optimizers import build_optimizer, build_scheduler
from src.utils.metrics import (
    compute_symmetrical_epipolar_errors,
    compute_pose_errors,
    aggregate_metrics,
)
from src.utils.plotting import make_matching_figures
from src.utils.comm import gather, all_gather
from src.utils.misc import lower_config, flattenList
from src.utils.profiler import PassThroughProfiler


class PL_Trainer(pl.LightningModule):
    def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
        """
        TODO:
            - use the new version of the PL logging API.
        """
        super().__init__()

        # Misc
        self.config = config  # full config
        _config = lower_config(self.config)
        self.model_cfg = lower_config(_config["model"])
        self.profiler = profiler or PassThroughProfiler()
        # Number of validation pairs to plot, split across processes
        self.n_vals_plot = max(
            config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1
        )

        # Matcher: TopicFM
        self.matcher = TopicFM(config=_config["model"])
        self.loss = TopicFMLoss(_config)

        # Pretrained weights
        if pretrained_ckpt:
            state_dict = torch.load(pretrained_ckpt, map_location="cpu")["state_dict"]
            self.matcher.load_state_dict(state_dict, strict=True)
            logger.info(f"Loaded '{pretrained_ckpt}' as pretrained checkpoint")

        # Testing
        self.dump_dir = dump_dir
    def configure_optimizers(self):
        optimizer = build_optimizer(self, self.config)
        scheduler = build_scheduler(self.config, optimizer)
        return [optimizer], [scheduler]
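
    # `configure_optimizers` above returns Lightning's ([optimizers], [schedulers])
    # pairing. Assumption (not fixed in this file): the concrete optimizer and
    # scheduler classes are resolved from the TRAINER section of the config inside
    # `build_optimizer` / `build_scheduler`.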
    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu,
        using_native_amp,
        using_lbfgs,
    ):
        # Learning-rate warm-up: linearly ramp the lr from WARMUP_RATIO * TRUE_LR
        # up to TRUE_LR over the first WARMUP_STEP optimizer steps.
        warmup_step = self.config.TRAINER.WARMUP_STEP
        if self.trainer.global_step < warmup_step:
            if self.config.TRAINER.WARMUP_TYPE == "linear":
                base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
                lr = base_lr + (self.trainer.global_step / warmup_step) * abs(
                    self.config.TRAINER.TRUE_LR - base_lr
                )
                for pg in optimizer.param_groups:
                    pg["lr"] = lr
            elif self.config.TRAINER.WARMUP_TYPE == "constant":
                pass
            else:
                raise ValueError(
                    f"Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}"
                )

        # Update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()
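
    # Worked example for the warm-up in `optimizer_step` (hypothetical values):
    # with TRUE_LR = 1e-3, WARMUP_RATIO = 0.1 and WARMUP_STEP = 1000, the lr
    # starts at base_lr = 1e-4 and ramps to ~1e-3 by step 1000, after which the
    # scheduler from `configure_optimizers` takes over.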
    def _trainval_inference(self, batch):
        with self.profiler.profile("Compute coarse supervision"):
            compute_supervision_coarse(batch, self.config)

        with self.profiler.profile("TopicFM"):
            self.matcher(batch)

        with self.profiler.profile("Compute fine supervision"):
            compute_supervision_fine(batch, self.config)

        with self.profiler.profile("Compute losses"):
            self.loss(batch)
    def _compute_metrics(self, batch):
        with self.profiler.profile("Compute metrics"):
            compute_symmetrical_epipolar_errors(batch)
            compute_pose_errors(batch, self.config)

        rel_pair_names = list(zip(*batch["pair_names"]))
        bs = batch["image0"].size(0)
        metrics = {
            "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)],
            "epi_errs": [
                batch["epi_errs"][batch["m_bids"] == b].cpu().numpy() for b in range(bs)
            ],
            "R_errs": batch["R_errs"],
            "t_errs": batch["t_errs"],
            "inliers": batch["inliers"],
        }
        ret_dict = {"metrics": metrics}
        return ret_dict, rel_pair_names
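
    # Layout of `metrics` built above: every value is a length-`bs` list indexed
    # by the pair's position in the batch, e.g. metrics["epi_errs"][b] holds the
    # epipolar errors of all matches predicted for pair b.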
def training_step(self, batch, batch_idx): |
|
self._trainval_inference(batch) |
|
|
|
|
|
if ( |
|
self.trainer.global_rank == 0 |
|
and self.global_step % self.trainer.log_every_n_steps == 0 |
|
): |
|
|
|
for k, v in batch["loss_scalars"].items(): |
|
self.logger.experiment.add_scalar(f"train/{k}", v, self.global_step) |
|
|
|
|
|
if self.config.TRAINER.ENABLE_PLOTTING: |
|
compute_symmetrical_epipolar_errors( |
|
batch |
|
) |
|
figures = make_matching_figures( |
|
batch, self.config, self.config.TRAINER.PLOT_MODE |
|
) |
|
for k, v in figures.items(): |
|
self.logger.experiment.add_figure( |
|
f"train_match/{k}", v, self.global_step |
|
) |
|
|
|
return {"loss": batch["loss"]} |
|
|
|
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        if self.trainer.global_rank == 0:
            self.logger.experiment.add_scalar(
                "train/avg_loss_on_epoch", avg_loss, global_step=self.current_epoch
            )
    def validation_step(self, batch, batch_idx):
        self._trainval_inference(batch)

        ret_dict, _ = self._compute_metrics(batch)

        # Plot only a subset of validation batches; the empty-list placeholder
        # keeps the `figures` keys consistent across batches for gathering.
        val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
        figures = {self.config.TRAINER.PLOT_MODE: []}
        if batch_idx % val_plot_interval == 0:
            figures = make_matching_figures(
                batch, self.config, mode=self.config.TRAINER.PLOT_MODE
            )

        return {
            **ret_dict,
            "loss_scalars": batch["loss_scalars"],
            "figures": figures,
        }
    def validation_epoch_end(self, outputs):
        # Handle multiple validation sets: `outputs` is a list of per-batch dicts
        # for a single val loader, or a list of such lists for multiple loaders.
        multi_outputs = (
            [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
        )
        multi_val_metrics = defaultdict(list)

        for valset_idx, outputs in enumerate(multi_outputs):
            # Since PL runs a sanity check before training, log it as epoch -1
            cur_epoch = self.trainer.current_epoch
            if (
                not self.trainer.resume_from_checkpoint
                and self.trainer.running_sanity_check
            ):
                cur_epoch = -1

            # 1. loss scalars: gather across ranks, flatten per key
            _loss_scalars = [o["loss_scalars"] for o in outputs]
            loss_scalars = {
                k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars]))
                for k in _loss_scalars[0]
            }

            # 2. val metrics: gather across ranks, flatten per key
            _metrics = [o["metrics"] for o in outputs]
            metrics = {
                k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics])))
                for k in _metrics[0]
            }

            val_metrics_4tb = aggregate_metrics(
                metrics, self.config.TRAINER.EPI_ERR_THR
            )
            for thr in [5, 10, 20]:
                multi_val_metrics[f"auc@{thr}"].append(val_metrics_4tb[f"auc@{thr}"])

            # 3. figures
            _figures = [o["figures"] for o in outputs]
            figures = {
                k: flattenList(gather(flattenList([_me[k] for _me in _figures])))
                for k in _figures[0]
            }

            # TensorBoard records: write only on rank 0
            if self.trainer.global_rank == 0:
                for k, v in loss_scalars.items():
                    mean_v = torch.stack(v).mean()
                    self.logger.experiment.add_scalar(
                        f"val_{valset_idx}/avg_{k}", mean_v, global_step=cur_epoch
                    )

                for k, v in val_metrics_4tb.items():
                    self.logger.experiment.add_scalar(
                        f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch
                    )

                for k, v in figures.items():
                    for plot_idx, fig in enumerate(v):
                        self.logger.experiment.add_figure(
                            f"val_match_{valset_idx}/{k}/pair-{plot_idx}",
                            fig,
                            cur_epoch,
                            close=True,
                        )
            plt.close("all")

        for thr in [5, 10, 20]:
            # Log the mean AUC over all validation sets
            self.log(
                f"auc@{thr}", torch.tensor(np.mean(multi_val_metrics[f"auc@{thr}"]))
            )
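
    # NOTE (assumption): the `self.log("auc@...")` calls above expose the pose
    # AUC to Lightning callbacks, e.g. a ModelCheckpoint monitoring "auc@10";
    # the actual callback configuration lives in the training script, outside
    # this module.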
    def test_step(self, batch, batch_idx):
        with self.profiler.profile("TopicFM"):
            self.matcher(batch)

        ret_dict, rel_pair_names = self._compute_metrics(batch)

        with self.profiler.profile("dump_results"):
            if self.dump_dir is not None:
                # Dump per-pair predictions and evaluation results
                keys_to_save = {"mkpts0_f", "mkpts1_f", "mconf", "epi_errs"}
                pair_names = list(zip(*batch["pair_names"]))
                bs = batch["image0"].shape[0]
                dumps = []
                for b_id in range(bs):
                    item = {}
                    mask = batch["m_bids"] == b_id
                    item["pair_names"] = pair_names[b_id]
                    item["identifier"] = "#".join(rel_pair_names[b_id])
                    for key in keys_to_save:
                        item[key] = batch[key][mask].cpu().numpy()
                    for key in ["R_errs", "t_errs", "inliers"]:
                        item[key] = batch[key][b_id]
                    dumps.append(item)
                ret_dict["dumps"] = dumps

        return ret_dict
    def test_epoch_end(self, outputs):
        # Metrics: gather across ranks, flatten per key
        _metrics = [o["metrics"] for o in outputs]
        metrics = {
            k: flattenList(gather(flattenList([_me[k] for _me in _metrics])))
            for k in _metrics[0]
        }

        # Collect the dumped results from all ranks
        if self.dump_dir is not None:
            Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
            _dumps = flattenList([o["dumps"] for o in outputs])
            dumps = flattenList(gather(_dumps))
            logger.info(
                f"Prediction and evaluation results will be saved to: {self.dump_dir}"
            )

        if self.trainer.global_rank == 0:
            print(self.profiler.summary())
            val_metrics_4tb = aggregate_metrics(
                metrics, self.config.TRAINER.EPI_ERR_THR
            )
            logger.info("\n" + pprint.pformat(val_metrics_4tb))
            if self.dump_dir is not None:
                np.save(Path(self.dump_dir) / "TopicFM_pred_eval", dumps)
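

# Hypothetical usage sketch (`config` and `data_module` are placeholders built
# by the training script, not defined in this module):
#
#   model = PL_Trainer(config, pretrained_ckpt="weights/topicfm.ckpt")
#   trainer = pl.Trainer(gpus=-1, max_epochs=config.TRAINER.MAX_EPOCHS)
#   trainer.fit(model, datamodule=data_module)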