File size: 409 Bytes
e5765b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
# Smoothing term added to both numerator and denominator so that an empty
# prediction with an empty target yields IoU == 1 instead of 0/0.
EPSILON = 1e-15


def binary_mean_iou(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Compute the intersection-over-union of thresholded logits against targets.

    Predictions are obtained by thresholding ``logits`` at 0: elements with a
    positive logit are predicted positive. Intersection and union are summed
    over every element of the inputs at once, so the result is a single IoU
    for the whole batch rather than an average of per-sample IoUs.

    When the shapes differ only by a singleton dimension at position 1 (e.g.
    ``(B, 1, H, W)`` versus ``(B, H, W)``), that dimension is squeezed from
    whichever tensor has it. This works in either direction.

    Args:
        logits: Raw model outputs; values ``> 0`` count as predicted positive.
        targets: Ground-truth mask with the same shape as ``logits``, up to a
            singleton dim 1. Presumably 0/1 valued (TODO confirm against
            callers); other values enter the sums as-is.

    Returns:
        A 0-d tensor equal to ``(intersection + EPSILON) / (union + EPSILON)``.
        This is 1 when both prediction and target are entirely empty.

    Raises:
        ValueError: If the shapes cannot be aligned. Without this check,
            mismatched shapes would silently broadcast and produce a
            meaningless score.
    """
    output = (logits > 0).int()
    if output.shape != targets.shape:
        logits_shape = tuple(logits.shape)
        targets_shape = tuple(targets.shape)
        # Drop a singleton channel dim from whichever side has one extra dim.
        if targets.dim() == output.dim() + 1 and targets.shape[1] == 1:
            targets = torch.squeeze(targets, 1)
        elif output.dim() == targets.dim() + 1 and output.shape[1] == 1:
            output = torch.squeeze(output, 1)
        if output.shape != targets.shape:
            raise ValueError(
                f"binary_mean_iou: logits shape {logits_shape} is incompatible "
                f"with targets shape {targets_shape}"
            )
    intersection = (targets * output).sum()
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = targets.sum() + output.sum() - intersection
    result = (intersection + EPSILON) / (union + EPSILON)
    return result
|