from pathlib import Path
from typing import Any

import pytorch_lightning as pl
import torch
import torchvision
import wandb


class EvalSaveCallback(pl.Callback):
    """Dump per-batch evaluation inputs and model outputs to ``.pt`` files.

    Active during both the validation and test loops. For every batch whose
    step output is truthy, one file named
    ``{batch_idx:06d}_{first_name_in_batch}.pt`` is written to ``save_dir``.

    NOTE(review): the file name depends only on ``batch_idx`` and the first
    ``name`` in the batch. Validation and test runs (or several dataloaders)
    writing to the same ``save_dir`` can therefore overwrite each other's files.
    """

    def __init__(self, save_dir: Path) -> None:
        super().__init__()
        # Normalise to Path so plain strings also work with the ``/`` join below.
        self.save_dir = Path(save_dir)

    def save(self, outputs, batch, batch_idx):
        """Serialize one batch's inputs and outputs with ``torch.save``.

        Args:
            outputs: step output mapping; must provide ``"output"`` and
                ``"valid_bev"``.
            batch: mapping providing ``"name"`` (indexable; element 0 is used
                in the file name), ``"image"`` and ``"seg_masks"``.
            batch_idx: index of the batch, zero-padded to 6 digits in the
                file name.
        """
        # Create the target directory lazily so a missing directory does not
        # abort evaluation part-way through.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        name = batch["name"]
        filename = self.save_dir / f"{batch_idx:06d}_{name[0]}.pt"
        torch.save(
            {
                "fpv": batch["image"],
                "seg_masks": batch["seg_masks"],
                "name": name,
                "output": outputs["output"],
                "valid_bev": outputs["valid_bev"],
            },
            filename,
        )

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: torch.Tensor | Any | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Save the batch unless the test step returned nothing (falsy output)."""
        if not outputs:
            return
        self.save(outputs, batch, batch_idx)

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: torch.Tensor | Any | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Save the batch unless the validation step returned nothing (falsy output)."""
        if not outputs:
            return
        self.save(outputs, batch, batch_idx)


class ImageLoggerCallback(pl.Callback):
    """Log input images plus per-class ground-truth / predicted masks as ``wandb.Image``s.

    Logging happens only on ``batch_idx == 0`` of the training and validation
    loops. The forward pass is recomputed under ``torch.no_grad()``, and the
    images go through ``trainer.logger.experiment.log``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Number of class channels to visualise, one GT/pred image pair each.
        self.num_classes = num_classes

    def log_image(self, trainer, pl_module, outputs, batch, batch_idx, mode="train"):
        """Build image grids for one batch and log them under ``"{mode}/images"``.

        Args:
            trainer: its ``logger.experiment.log`` receives the image list.
            pl_module: unused; kept for signature compatibility.
            outputs: mapping with ``"output"`` (4-D, class channel at dim 1)
                and ``"valid_bev"`` (used as a boolean mask after dropping its
                last entry along the final dim).
            batch: mapping with ``"image"`` and ``"seg_masks"`` (class on the
                last dim).
            batch_idx: unused; kept for signature compatibility.
            mode: prefix of the logged key, e.g. ``"train"`` or ``"val"``.

        ``outputs`` is not modified.
        """
        fpv_grid = torchvision.utils.make_grid(batch["image"], nrow=8, normalize=False)
        images = [wandb.Image(fpv_grid, caption="fpv")]

        # permute() returns a view. Clone so the masking below does not write
        # through into the caller's outputs["output"].
        pred = outputs["output"].permute(0, 2, 3, 1).clone()
        # Zero predictions wherever valid_bev is 0.
        # NOTE(review): the last entry of valid_bev's final dim is dropped,
        # presumably so the mask lines up with the class channels — confirm.
        pred[outputs["valid_bev"][..., :-1] == 0] = 0
        # Binarise at 0.5 and return to channel-first layout for per-class slicing.
        pred = (pred > 0.5).float().permute(0, 3, 1, 2)

        for i in range(self.num_classes):
            gt_class_i = batch["seg_masks"][..., i]
            gt_class_i_grid = torchvision.utils.make_grid(
                gt_class_i.unsqueeze(1), nrow=8, normalize=False, pad_value=0
            )
            pred_class_i_grid = torchvision.utils.make_grid(
                pred[:, i].unsqueeze(1), nrow=8, normalize=False, pad_value=0
            )
            images.extend(
                [
                    wandb.Image(gt_class_i_grid, caption=f"gt_class_{i}"),
                    wandb.Image(pred_class_i_grid, caption=f"pred_class_{i}"),
                ]
            )

        trainer.logger.experiment.log({f"{mode}/images": images})

    def on_validation_batch_end(
        self,
        trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx: int = 0,
    ):
        """On the first validation batch, rerun the model and log its images.

        ``dataloader_idx`` is accepted because Lightning passes it when several
        validation dataloaders are configured. It is otherwise unused.
        """
        if batch_idx != 0:
            return
        # The step output Lightning passes in is ignored; the forward pass is
        # recomputed here to obtain the "output"/"valid_bev" mapping.
        with torch.no_grad():
            outputs = pl_module(batch)
        self.log_image(trainer, pl_module, outputs, batch, batch_idx, mode="val")

    def on_train_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx):
        """On the first training batch, run the model in eval mode and log its images.

        Every submodule's train/eval flag is restored afterwards, even if the
        forward pass or logging raises. A blanket ``pl_module.train()`` would
        also flip submodules the user deliberately kept in eval mode.
        """
        if batch_idx != 0:
            return
        saved_modes = [(module, module.training) for module in pl_module.modules()]
        pl_module.eval()
        try:
            with torch.no_grad():
                outputs = pl_module(batch)
            self.log_image(trainer, pl_module, outputs, batch, batch_idx, mode="train")
        finally:
            for module, was_training in saved_modes:
                module.training = was_training