import torch | |
# Smoothing term added to numerator and denominator so the ratio stays
# defined (and evaluates to 1) when both masks are empty.
EPSILON = 1e-15


def binary_mean_iou(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Return the intersection-over-union of a thresholded prediction and a mask.

    Elements of ``logits`` that are strictly greater than 0 are treated as
    positive predictions. The intersection and union are summed over *all*
    elements of the tensors, so the result is a single global IoU rather
    than an average of per-sample scores.

    If the two tensors differ by one extra axis, dimension 1 of the larger
    tensor is squeezed (a no-op unless that axis has size 1) so that inputs
    such as ``(B, 1, H, W)`` and ``(B, H, W)`` can be compared, regardless
    of which argument carries the extra axis.

    Args:
        logits: Raw prediction scores.
        targets: Ground-truth mask; presumably holds 0/1 values (not
            validated here).

    Returns:
        A 0-dim tensor holding ``(intersection + EPSILON) / (union + EPSILON)``.

    Raises:
        ValueError: If the shapes still differ after the squeeze attempt.
            Letting them broadcast would silently produce a wrong score.
    """
    output = (logits > 0).int()

    if output.shape != targets.shape:
        if targets.dim() == output.dim() + 1:
            targets = torch.squeeze(targets, 1)
        elif output.dim() == targets.dim() + 1:
            output = torch.squeeze(output, 1)

    if output.shape != targets.shape:
        raise ValueError(
            f"logits shape {tuple(logits.shape)} is incompatible with "
            f"targets shape {tuple(targets.shape)}"
        )

    intersection = (targets * output).sum()
    union = targets.sum() + output.sum() - intersection
    return (intersection + EPSILON) / (union + EPSILON)