|
import torch |
|
import pytorch_lightning as pl |
|
from pathlib import Path |
|
from typing import Any |
|
import torchvision |
|
import wandb |
|
|
|
|
|
class EvalSaveCallback(pl.Callback):
    """Dump per-batch inputs and model outputs to ``.pt`` files during val/test.

    For every validation/test batch whose step output is a non-empty dict,
    one file is written to ``save_dir`` containing the batch's ``image``
    (stored as ``"fpv"``), ``seg_masks`` and ``name`` entries plus the step
    output's ``output`` and ``valid_bev`` entries.
    """

    def __init__(self, save_dir: Path) -> None:
        """
        Args:
            save_dir: Directory the ``.pt`` files are written to. A ``str``
                is also accepted. The directory is created on first save if
                it does not exist yet.
        """
        super().__init__()
        self.save_dir = Path(save_dir)

    @staticmethod
    def _to_cpu(value: Any) -> Any:
        """Return a detached CPU copy of a tensor; pass any other value through."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu()
        return value

    def save(self, outputs, batch, batch_idx, dataloader_idx: int = 0):
        """Write one batch and its step outputs to ``save_dir``.

        Args:
            outputs: Step output mapping; must contain ``"output"`` and
                ``"valid_bev"``.
            batch: Batch mapping; must contain ``"image"``, ``"seg_masks"``
                and ``"name"`` (``name[0]`` is used in the filename).
            batch_idx: Index of the batch within its dataloader.
            dataloader_idx: Index of the dataloader. Non-zero values are
                prepended to the filename so batches from different
                dataloaders do not overwrite each other; index 0 keeps the
                plain ``{batch_idx:06d}_{name}.pt`` name.
        """
        name = batch["name"]
        save_dir = Path(self.save_dir)
        # torch.save does not create missing parent directories.
        save_dir.mkdir(parents=True, exist_ok=True)

        prefix = f"dl{dataloader_idx}_" if dataloader_idx else ""
        filename = save_dir / f"{prefix}{batch_idx:06d}_{name[0]}.pt"

        # Detach and move to CPU so the files can be loaded without a GPU and
        # do not pickle autograd history.
        torch.save(
            {
                "fpv": self._to_cpu(batch["image"]),
                "seg_masks": self._to_cpu(batch["seg_masks"]),
                "name": name,
                "output": self._to_cpu(outputs["output"]),
                "valid_bev": self._to_cpu(outputs["valid_bev"]),
            },
            filename,
        )

    def _save_if_dict(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
        """Save the batch only when the step returned a non-empty dict.

        Step outputs may be ``None``, a loss tensor, or a dict. Only a dict
        carries the entries ``save`` needs. Checking the type first also
        avoids ``not tensor``, which raises for multi-element tensors.
        """
        if not isinstance(outputs, dict) or not outputs:
            return
        self.save(outputs, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(self, trainer: pl.Trainer,
                          pl_module: pl.LightningModule,
                          outputs: torch.Tensor | Any | None,
                          batch: Any,
                          batch_idx: int,
                          dataloader_idx: int = 0) -> None:
        """Save the test batch and its step outputs (dict outputs only)."""
        self._save_if_dict(outputs, batch, batch_idx, dataloader_idx)

    def on_validation_batch_end(self, trainer: pl.Trainer,
                                pl_module: pl.LightningModule,
                                outputs: torch.Tensor | Any | None,
                                batch: Any,
                                batch_idx: int,
                                dataloader_idx: int = 0) -> None:
        """Save the validation batch and its step outputs (dict outputs only)."""
        self._save_if_dict(outputs, batch, batch_idx, dataloader_idx)
|
|
|
|
|
class ImageLoggerCallback(pl.Callback):
    """Log FPV images and per-class ground-truth / predicted masks to W&B.

    On batch 0 of every training and validation epoch, the module is run
    again on that batch (under ``torch.no_grad``). The resulting images are
    logged through ``trainer.logger.experiment.log`` under
    ``"{mode}/images"``.
    """

    def __init__(self, num_classes):
        """
        Args:
            num_classes: Number of classes to log. Class ``i`` is read from
                ``batch["seg_masks"][..., i]`` and from channel ``i`` of
                ``outputs["output"]``.
        """
        super().__init__()
        self.num_classes = num_classes

    def log_image(self, trainer, pl_module, outputs, batch, batch_idx, mode="train"):
        """Build image grids for one batch and log them.

        Args:
            trainer: Trainer whose ``logger.experiment`` receives the images.
                Nothing is logged when ``trainer.logger`` is None.
            pl_module: Unused; kept for interface compatibility.
            outputs: Model output mapping with ``"output"`` and
                ``"valid_bev"``. It is not modified.
            batch: Batch mapping with ``"image"`` and ``"seg_masks"``.
            batch_idx: Unused; kept for interface compatibility.
            mode: Key prefix for the logged entry (``"{mode}/images"``).
        """
        logger = trainer.logger
        if logger is None:
            return

        fpv_grid = torchvision.utils.make_grid(batch["image"], nrow=8, normalize=False)
        images = [wandb.Image(fpv_grid, caption="fpv")]

        # permute() returns a view, so clone before masking. Otherwise the
        # assignment below would overwrite the caller's outputs["output"].
        pred = outputs["output"].detach().permute(0, 2, 3, 1).clone()
        # NOTE(review): assumes valid_bev is channel-last with one extra
        # trailing channel compared to output (dropped via [..., :-1]) - confirm.
        pred[outputs["valid_bev"][..., :-1] == 0] = 0
        pred = (pred > 0.5).float().permute(0, 3, 1, 2)

        for i in range(self.num_classes):
            gt_class_i_grid = torchvision.utils.make_grid(
                batch["seg_masks"][..., i].unsqueeze(1), nrow=8, normalize=False, pad_value=0)
            pred_class_i_grid = torchvision.utils.make_grid(
                pred[:, i].unsqueeze(1), nrow=8, normalize=False, pad_value=0)
            images += [
                wandb.Image(gt_class_i_grid, caption=f"gt_class_{i}"),
                wandb.Image(pred_class_i_grid, caption=f"pred_class_{i}"),
            ]

        logger.experiment.log({f"{mode}/images": images})

    def on_validation_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch,
                                batch_idx, dataloader_idx: int = 0):
        """Log images for the first validation batch.

        ``dataloader_idx`` is accepted because Lightning passes it when more
        than one validation dataloader is configured.
        """
        if batch_idx != 0 or trainer.logger is None:
            return
        with torch.no_grad():
            preds = pl_module(batch)
        self.log_image(trainer, pl_module, preds, batch, batch_idx, mode="val")

    def on_train_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx):
        """Log images for the first training batch, using an eval-mode forward pass."""
        if batch_idx != 0 or trainer.logger is None:
            return
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                preds = pl_module(batch)
            self.log_image(trainer, pl_module, preds, batch, batch_idx, mode="train")
        finally:
            # Restore the previous mode even if the forward pass or logging raised.
            pl_module.train(was_training)
|
|