"""Save a side-by-side overlay: input | ground-truth | prediction. Used at test time to qualitatively inspect each method's output. Denormalizes the input tensor back to a viewable image and color-codes class masks. """ from __future__ import annotations import numpy as np import cv2 import torch _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) _IMAGENET_STD = np.array([0.229, 0.224, 0.225]) # distinct colors for up to 6 classes (0 = background -> transparent/black) _PALETTE = np.array([ [0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255], ], dtype=np.uint8) def _denorm(img: torch.Tensor) -> np.ndarray: x = img.float().numpy() # C,H,W c = x.shape[0] x = np.transpose(x, (1, 2, 0)) # H,W,C if c == 3: x = x * _IMAGENET_STD + _IMAGENET_MEAN else: x = x * 0.5 + 0.5 x = np.repeat(x, 3, axis=2) if x.shape[2] == 1 else x x = np.clip(x * 255.0, 0, 255).astype(np.uint8) return x def _colorize(mask: np.ndarray, num_classes: int) -> np.ndarray: h, w = mask.shape out = np.zeros((h, w, 3), dtype=np.uint8) for c in range(1, num_classes): out[mask == c] = _PALETTE[c % len(_PALETTE)] return out def save_overlay(img: torch.Tensor, gt: np.ndarray, pred: np.ndarray, num_classes: int, path: str, alpha: float = 0.5) -> None: base = _denorm(img) h, w = gt.shape base = cv2.resize(base, (w, h), interpolation=cv2.INTER_LINEAR) gt_c = _colorize(gt, num_classes) pr_c = _colorize(pred, num_classes) gt_o = cv2.addWeighted(base, 1 - alpha, gt_c, alpha, 0) pr_o = cv2.addWeighted(base, 1 - alpha, pr_c, alpha, 0) panel = np.concatenate([base, gt_o, pr_o], axis=1) cv2.imwrite(path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))