hasibzunair's picture
add files
1803579
raw
history blame
4.07 kB
"""
Code borrowed from SelfMask: https://github.com/NoelShin/selfmask
"""
import torch
class FMeasure:
def __init__(
self,
default_thres: float = 0.5,
beta_square: float = 0.3,
n_bins: int = 255,
eps: float = 1e-7,
):
"""
:param default_thres: a hyperparameter for F-measure that is used to binarize a predicted mask. Default: 0.5
:param beta_square: a hyperparameter for F-measure. Default: 0.3
:param n_bins: the number of thresholds that will be tested for F-max. Default: 255
:param eps: a small value for numerical stability
"""
self.beta_square = beta_square
self.default_thres = default_thres
self.eps = eps
self.n_bins = n_bins
def _compute_precision_recall(
self, binary_pred_mask: torch.Tensor, gt_mask: torch.Tensor
) -> torch.Tensor:
"""
:param binary_pred_mask: (B x H x W) or (H x W)
:param gt_mask: (B x H x W) or (H x W), should be the same with binary_pred_mask
"""
tp = torch.logical_and(binary_pred_mask, gt_mask).sum(dim=(-1, -2))
tp_fp = binary_pred_mask.sum(dim=(-1, -2))
tp_fn = gt_mask.sum(dim=(-1, -2))
prec = tp / (tp_fp + self.eps)
recall = tp / (tp_fn + self.eps)
return prec, recall
def _compute_f_measure(
self,
pred_mask: torch.Tensor,
gt_mask: torch.Tensor,
thresholds: torch.Tensor = None,
) -> torch.Tensor:
if thresholds is None:
binary_pred_mask = pred_mask > self.default_thres
else:
binary_pred_mask = pred_mask > thresholds
prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
f_measure = ((1 + (self.beta_square**2)) * prec * recall) / (
(self.beta_square**2) * prec + recall + self.eps
)
return f_measure.cpu()
def _compute_f_max(
self, pred_mask: torch.Tensor, gt_mask: torch.Tensor
) -> torch.Tensor:
"""Compute self.n_bins + 1 F-measures, each of which has a different threshold, then return the maximum
F-measure among them.
:param pred_mask: (H x W)
:param gt_mask: (H x W)
"""
# pred_masks, gt_masks: H x W -> self.n_bins x H x W
pred_masks = pred_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1)
gt_masks = gt_mask.unsqueeze(dim=0).repeat(self.n_bins, 1, 1)
# thresholds: self.n_bins x 1 x 1
thresholds = (
torch.arange(0, 1, 1 / self.n_bins)
.view(self.n_bins, 1, 1)
.to(pred_masks.device)
)
# f_measures: self.n_bins
f_measures = self._compute_f_measure(pred_masks, gt_masks, thresholds)
return torch.max(f_measures).cpu(), f_measures
def _compute_f_mean(
self,
pred_mask: torch.Tensor,
gt_mask: torch.Tensor,
) -> torch.Tensor:
adaptive_thres = 2 * pred_mask.mean(dim=(-1, -2), keepdim=True)
binary_pred_mask = pred_mask > adaptive_thres
prec, recall = self._compute_precision_recall(binary_pred_mask, gt_mask)
f_mean = ((1 + (self.beta_square**2)) * prec * recall) / (
(self.beta_square**2) * prec + recall + self.eps
)
return f_mean.cpu()
def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> dict:
"""
:param pred_mask: (H x W) a normalized prediction mask with values in [0, 1]
:param gt_mask: (H x W) a binary ground truth mask with values in {0, 1}
:return: a dictionary with keys being "f_measure" and "f_max" and values being the respective values.
"""
outputs: dict = dict()
for k in ("f_measure", "f_mean"):
outputs.update({k: getattr(self, f"_compute_{k}")(pred_mask, gt_mask)})
f_max_, all_f = self._compute_f_max(pred_mask, gt_mask)
outputs["f_max"] = f_max_
outputs["all_f"] = all_f # List of all f values for all thresholds
return outputs