|
""" |
|
Code adapted from SelfMask: https://github.com/NoelShin/selfmask |
|
""" |
|
|
|
from typing import Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def compute_iou( |
|
pred_mask: Union[np.ndarray, torch.Tensor], |
|
gt_mask: Union[np.ndarray, torch.Tensor], |
|
threshold: Optional[float] = 0.5, |
|
eps: float = 1e-7, |
|
) -> Union[np.ndarray, torch.Tensor]: |
|
""" |
|
:param pred_mask: (B x H x W) or (H x W) |
|
:param gt_mask: (B x H x W) or (H x W), same shape with pred_mask |
|
:param threshold: a binarization threshold |
|
:param eps: a small value for computational stability |
|
:return: (B) or (1) |
|
""" |
|
assert pred_mask.shape == gt_mask.shape, f"{pred_mask.shape} != {gt_mask.shape}" |
|
|
|
|
|
if threshold is not None: |
|
pred_mask = pred_mask > threshold |
|
if isinstance(pred_mask, np.ndarray): |
|
intersection = np.logical_and(pred_mask, gt_mask).sum(axis=(-1, -2)) |
|
union = np.logical_or(pred_mask, gt_mask).sum(axis=(-1, -2)) |
|
ious = intersection / (union + eps) |
|
else: |
|
intersection = torch.logical_and(pred_mask, gt_mask).sum(dim=(-1, -2)) |
|
union = torch.logical_or(pred_mask, gt_mask).sum(dim=(-1, -2)) |
|
ious = (intersection / (union + eps)).cpu() |
|
return ious |
|
|