import numpy as np
import numpy.typing as npt
from torch.nn import functional as F
from torch import Tensor


def calculate_dice_loss(inputs: Tensor, targets: Tensor, num_masks: int = 1) -> Tensor:
    """Compute the soft Dice loss between predicted mask logits and targets.

    ``inputs`` are logits: they are passed through a sigmoid and flattened from
    dimension 1 onward, so each row along dimension 0 is one mask. ``targets``
    is flattened the same way when it has more than two dimensions, so it may
    be given either already flattened or with the same shape as ``inputs``.

    Args:
        inputs: Mask logits, one mask per index of dimension 0.
        targets: Target masks (presumably 0/1 values -- TODO confirm with callers).
        num_masks: Normaliser applied to the summed per-mask losses.

    Returns:
        Scalar tensor ``sum(per-mask dice loss) / num_masks``. The ``+ 1`` in
        numerator and denominator smooths the ratio and keeps it finite when
        both prediction and target are empty.
    """
    inputs = inputs.sigmoid().flatten(1)
    # Without this, a target shaped like the *unflattened* inputs, e.g. (N, 1, P),
    # broadcasts against the (N, P) inputs into an (N, N, P) cross product and
    # silently produces a wrong loss.
    if targets.dim() > 2:
        targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


def calculate_sigmoid_focal_loss(
    inputs: Tensor,
    targets: Tensor,
    num_masks: int = 1,
    alpha: float = 0.25,
    gamma: float = 2,
) -> Tensor:
    """Compute the sigmoid focal loss, ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        inputs: Logits. ``F.binary_cross_entropy_with_logits`` requires them to
            have the same shape as ``targets``. Presumably (N, P) with one
            flattened mask per row -- TODO confirm with callers, since the
            reduction below averages over dimension 1 only.
        targets: Targets with the same shape as ``inputs`` (presumably 0/1).
        num_masks: Normaliser applied after summing.
        alpha: Weight for positive targets (``1 - alpha`` for negatives). A
            negative value disables the alpha weighting.
        gamma: Focusing exponent applied to ``(1 - p_t)``.

    Returns:
        Scalar tensor: the element-wise loss averaged over dimension 1, summed
        over the remaining dimensions, and divided by ``num_masks``.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability of the ground-truth class.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    # Down-weight well-classified elements (p_t close to 1).
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean(1).sum() / num_masks


def _binarize_mask(mask: npt.NDArray, foreground_value: int) -> npt.NDArray:
    """Return a boolean array marking pixels whose (channel-summed) value equals ``foreground_value``."""
    # Arrays with 3+ dims are collapsed by summing axis 2 (the channel axis of an
    # (H, W, C) image); a 2-D single-channel mask is compared as-is.
    collapsed = mask if mask.ndim == 2 else mask.sum(axis=2)
    return collapsed == foreground_value


def calculate_iou(
    mask1: npt.NDArray,
    mask2: npt.NDArray,
    *,
    foreground_value: int = 128,
    empty_iou: float = float("nan"),
) -> float:
    """Compute the intersection-over-union of the foreground regions of two masks.

    A pixel is foreground when the sum of its values along axis 2 equals
    ``foreground_value``. For a 2-D mask, the pixel value itself is compared.

    Args:
        mask1: First mask, (H, W, C) or (H, W).
        mask2: Second mask; must be broadcast-compatible with ``mask1``.
        foreground_value: Channel-sum that identifies foreground pixels.
        empty_iou: Value returned when neither mask has any foreground pixel
            (the union is empty). Defaults to NaN, the value the plain
            ``0 / 0`` division produced, but without a RuntimeWarning.

    Returns:
        ``|fg1 & fg2| / |fg1 | fg2|``, or ``empty_iou`` when the union is empty.
    """
    fg1 = _binarize_mask(mask1, foreground_value)
    fg2 = _binarize_mask(mask2, foreground_value)
    union = np.sum(np.logical_or(fg1, fg2))
    if union == 0:
        return empty_iou
    intersection = np.sum(np.logical_and(fg1, fg2))
    return intersection / union