dillonlaird's picture
initial commit
6723494
raw
history blame
No virus
1.33 kB
import numpy as np
import numpy.typing as npt
from torch.nn import functional as F
from torch import Tensor
def calculate_dice_loss(inputs: Tensor, targets: Tensor, num_masks: int = 1) -> Tensor:
    """Return the Dice loss between predicted logits and target masks.

    Args:
        inputs: Raw logits. A sigmoid is applied, then every dimension
            after the first is flattened into one.
        targets: Target values. They are used as-is (not flattened), so
            they must broadcast against the flattened ``inputs``.
            NOTE(review): presumably binary masks already shaped
            ``(N, H*W)``; confirm against the caller.
        num_masks: Divisor applied to the summed per-row loss.

    Returns:
        A scalar tensor: the per-row Dice losses summed and divided by
        ``num_masks``.
    """
    probs = inputs.sigmoid().flatten(1)
    overlap = (probs * targets).sum(-1)
    total = probs.sum(-1) + targets.sum(-1)
    # The +1 on both sides keeps the ratio defined when the sums are zero.
    dice = (2 * overlap + 1) / (total + 1)
    return (1 - dice).sum() / num_masks
def calculate_sigmoid_focal_loss(
    inputs: Tensor,
    targets: Tensor,
    num_masks: int = 1,
    alpha: float = 0.25,
    gamma: float = 2,
) -> Tensor:
    """Return the sigmoid focal loss between logits and targets.

    Args:
        inputs: Raw logits; must have at least two dimensions because the
            loss is averaged over dimension 1.
        targets: Target values with the same shape as ``inputs``.
            NOTE(review): presumably binary (0/1) floats; confirm against
            the caller.
        num_masks: Divisor applied to the summed loss.
        alpha: Weight given to positions where the target is 1; positions
            where it is 0 get ``1 - alpha``. A negative value disables the
            weighting altogether.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.

    Returns:
        The element-wise focal loss averaged over dimension 1, summed, and
        divided by ``num_masks``.
    """
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    probs = inputs.sigmoid()
    # Probability the model assigns to the value given in ``targets``.
    target_prob = probs * targets + (1 - probs) * (1 - targets)
    focal = bce * ((1 - target_prob) ** gamma)

    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    per_row = focal.mean(1)
    return per_row.sum() / num_masks
def calculate_iou(
    mask1: npt.NDArray,
    mask2: npt.NDArray,
    mask_value: int = 128,
) -> float:
    """Return the intersection-over-union of two masks.

    Each array is summed along axis 2, and a position counts as foreground
    when that sum equals ``mask_value``.
    NOTE(review): presumably the inputs are ``(H, W, C)`` colour-coded mask
    images whose foreground colour has channels summing to 128; confirm
    against the caller.

    Args:
        mask1: First mask array; must have at least three dimensions.
        mask2: Second mask array, shaped so that its axis-2 sum broadcasts
            against that of ``mask1``.
        mask_value: Axis-2 sum that marks a foreground position. Defaults
            to 128, the value that was previously hard-coded.

    Returns:
        Foreground intersection divided by foreground union, as a built-in
        ``float``. Returns 0.0 when neither mask has any foreground.
    """
    foreground1 = mask1.sum(axis=2) == mask_value
    foreground2 = mask2.sum(axis=2) == mask_value

    union = np.logical_or(foreground1, foreground2).sum()
    if union == 0:
        # Both masks are empty: dividing would evaluate 0 / 0, which yields
        # nan and a NumPy RuntimeWarning instead of a usable score.
        return 0.0

    intersection = np.logical_and(foreground1, foreground2).sum()
    return float(intersection / union)