|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from typing import TYPE_CHECKING, Callable, Optional |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed |
|
from monai.config import IgniteInfo |
|
from monai.utils import min_version, optional_import |
|
|
|
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") |
|
make_grid, _ = optional_import("torchvision.utils", name="make_grid") |
|
Image, _ = optional_import("PIL.Image") |
|
ImageDraw, _ = optional_import("PIL.ImageDraw") |
|
|
|
if TYPE_CHECKING: |
|
from ignite.engine import Engine |
|
from torch.utils.tensorboard import SummaryWriter |
|
else: |
|
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") |
|
SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") |
|
|
|
|
|
class TensorBoardImageHandler: |
|
def __init__( |
|
self, |
|
summary_writer: Optional[SummaryWriter] = None, |
|
log_dir: str = "./runs", |
|
tag_name="val", |
|
interval: int = 1, |
|
batch_transform: Callable = lambda x: x, |
|
output_transform: Callable = lambda x: x, |
|
batch_limit=1, |
|
device=None, |
|
) -> None: |
|
self.writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer |
|
self.tag_name = tag_name |
|
self.interval = interval |
|
self.batch_transform = batch_transform |
|
self.output_transform = output_transform |
|
self.batch_limit = batch_limit |
|
self.device = device |
|
|
|
self.logger = logging.getLogger(__name__) |
|
|
|
if torch.distributed.is_initialized(): |
|
self.tag_name = f"{self.tag_name}-r{torch.distributed.get_rank()}" |
|
|
|
def attach(self, engine: Engine) -> None: |
|
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self, "epoch") |
|
|
|
def __call__(self, engine: Engine, action) -> None: |
|
epoch = engine.state.epoch |
|
batch_data = self.batch_transform(engine.state.batch) |
|
output_data = self.output_transform(engine.state.output) |
|
|
|
self.write_images(batch_data, output_data, epoch) |
|
|
|
def write_images(self, batch_data, output_data, epoch): |
|
for bidx in range(len(batch_data)): |
|
image = batch_data[bidx]["image"].detach().cpu().numpy() |
|
y = output_data[bidx]["label"].detach().cpu().numpy() |
|
|
|
tag_prefix = f"{self.tag_name} - b{bidx} - " if self.batch_limit != 1 else "" |
|
img_np = image[:3] |
|
img_np[0, :, :] = np.where(y[0] > 0, 1, img_np[0, :, :]) |
|
img_tensor = make_grid(torch.from_numpy(img_np), normalize=True) |
|
self.writer.add_image(tag=f"{tag_prefix}Image", img_tensor=img_tensor, global_step=epoch) |
|
|
|
y_pred = output_data[bidx]["pred"].detach().cpu().numpy() |
|
|
|
cl = np.count_nonzero(y) |
|
cp = np.count_nonzero(y_pred) |
|
self.logger.info( |
|
"{} => {} - Image: {};" |
|
" Label: {} (nz: {});" |
|
" Pred: {} (nz: {});" |
|
" Diff: {:.2f}%;" |
|
" Sig: (pos-nz: {}, neg-nz: {})".format( |
|
self.tag_name, |
|
bidx, |
|
image.shape, |
|
y.shape, |
|
cl, |
|
y_pred.shape, |
|
cp, |
|
100 * (cp - cl) / (cl + 1), |
|
np.count_nonzero(image[3]), |
|
np.count_nonzero(image[4]), |
|
) |
|
) |
|
|
|
tag_prefix = f"{self.tag_name} - b{bidx} - " if self.batch_limit != 1 else f"{self.tag_name} - " |
|
label_pred = [y, y_pred, image[3][None] > 0, image[4][None] > 0] |
|
label_pred_tag = f"{tag_prefix}Label vs Pred vs Pos vs Neg" |
|
|
|
img_tensor = make_grid(tensor=torch.from_numpy(np.array(label_pred)), nrow=4, normalize=True, pad_value=10) |
|
self.writer.add_image(tag=label_pred_tag, img_tensor=img_tensor, global_step=epoch) |
|
|
|
if self.batch_limit == 1 or bidx == (self.batch_limit - 1): |
|
break |
|
|