# Hugging Face Hub page header (not part of the program):
# uploaded by MaybeRichard, commit message "Upload folder using huggingface_hub",
# commit b8fae22 (verified), file size 1.84 kB.
"""Save a side-by-side overlay: input | ground-truth | prediction.
Used at test time to qualitatively inspect each method's output. Denormalizes the
input tensor back to a viewable image and color-codes class masks.
"""
from __future__ import annotations

import os

import cv2
import numpy as np
import torch
# Per-channel statistics used by _denorm to undo normalization of
# three-channel inputs (named after ImageNet; channels are treated as RGB).
_IMAGENET_MEAN = np.array((0.485, 0.456, 0.406))
_IMAGENET_STD = np.array((0.229, 0.224, 0.225))

# Class-index -> RGB color table. Row 0 is the background (black); rows 1-6
# provide six distinct foreground colors.
_PALETTE = np.array(
    (
        (0, 0, 0),        # 0: background
        (255, 0, 0),      # 1: red
        (0, 255, 0),      # 2: green
        (0, 0, 255),      # 3: blue
        (255, 255, 0),    # 4: yellow
        (255, 0, 255),    # 5: magenta
        (0, 255, 255),    # 6: cyan
    ),
    dtype=np.uint8,
)
def _denorm(img: torch.Tensor) -> np.ndarray:
    """Convert a normalized C,H,W image tensor into a viewable H,W,C uint8 array.

    Three-channel inputs are un-normalized with ``_IMAGENET_STD`` and
    ``_IMAGENET_MEAN``. Any other channel count is mapped with
    ``x * 0.5 + 0.5`` (presumably inverting a mean=0.5 / std=0.5
    normalization -- confirm against the data pipeline). A single-channel
    image is repeated to three channels so it can be blended with color masks.

    Args:
        img: Image tensor laid out as (C, H, W). It may be attached to an
            autograd graph and may live on any device.

    Returns:
        ``uint8`` array laid out as (H, W, C), with values scaled by 255 and
        clipped to [0, 255]. C is 3 for one- and three-channel inputs.
    """
    # detach() and cpu() come first: Tensor.numpy() raises for tensors that
    # require grad or that are stored on a non-CPU device.
    x = img.detach().cpu().float().numpy()  # C,H,W
    c = x.shape[0]
    x = np.transpose(x, (1, 2, 0))  # H,W,C
    if c == 3:
        x = x * _IMAGENET_STD + _IMAGENET_MEAN
    else:
        x = x * 0.5 + 0.5
    if x.shape[2] == 1:
        # Grayscale -> three identical channels.
        x = np.repeat(x, 3, axis=2)
    return np.clip(x * 255.0, 0, 255).astype(np.uint8)
def _colorize(mask: np.ndarray, num_classes: int) -> np.ndarray:
    """Render a 2-D class-index mask as an RGB image using ``_PALETTE``.

    Pixels equal to a class index in ``1 .. num_classes - 1`` receive that
    class's color. Every other pixel (class 0 or any value outside that
    range) stays black.

    Args:
        mask: 2-D array of class indices.
        num_classes: Number of classes including background (index 0).

    Returns:
        ``uint8`` array of shape (H, W, 3).

    Raises:
        ValueError: If ``mask`` is not 2-D (shape unpacking fails).
    """
    h, w = mask.shape
    out = np.zeros((h, w, 3), dtype=np.uint8)
    for c in range(1, num_classes):
        # Cycle over the foreground rows only (index 1 onward). A plain
        # ``c % len(_PALETTE)`` would send class 7, 14, ... to row 0, i.e.
        # black, making them indistinguishable from background.
        out[mask == c] = _PALETTE[1 + (c - 1) % (len(_PALETTE) - 1)]
    return out
def save_overlay(img: torch.Tensor, gt: np.ndarray, pred: np.ndarray,
                 num_classes: int, path: str, alpha: float = 0.5) -> None:
    """Write a three-panel image: input | ground-truth overlay | prediction overlay.

    The input tensor is denormalized, resized to the ground-truth mask's
    resolution, and blended with the color-coded masks. The blend weights
    apply to every pixel, so background regions of the overlays are dimmed
    by ``1 - alpha`` rather than left untouched.

    Args:
        img: Normalized image tensor, (C, H, W).
        gt: 2-D ground-truth class-index mask; its shape sets the panel size.
        pred: 2-D predicted class-index mask. Assumed to have the same shape
            as ``gt`` -- the blend does not resize it.
        num_classes: Number of classes including background (index 0).
        path: Output file path; missing parent directories are created.
        alpha: Weight of the color mask in the blend (image gets ``1 - alpha``).

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    base = _denorm(img)
    h, w = gt.shape
    # Match the mask resolution so image and masks line up pixel for pixel.
    base = cv2.resize(base, (w, h), interpolation=cv2.INTER_LINEAR)

    gt_c = _colorize(gt, num_classes)
    pr_c = _colorize(pred, num_classes)
    gt_o = cv2.addWeighted(base, 1 - alpha, gt_c, alpha, 0)
    pr_o = cv2.addWeighted(base, 1 - alpha, pr_c, alpha, 0)
    panel = np.concatenate([base, gt_o, pr_o], axis=1)

    # cv2.imwrite does not create directories and reports failure through its
    # return value, so an unchecked call can silently write nothing.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # The panel is RGB; OpenCV expects BGR channel order when writing.
    if not cv2.imwrite(path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed to write overlay to {path!r}")