din0s's picture
Add code
d4ab5ac unverified
raw
history blame
No virus
1.95 kB
"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/utils/util.py
"""
from typing import Optional

import torch
from torch import Tensor
def accuracy_precision_recall_f1(
    y_pred: Tensor, y_true: Tensor, average: bool = True
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Calculates the accuracy, precision, recall and f1 score given the predicted and true labels.

    Per-class scores are read off ``confusion_matrix(y_pred, y_true)``. Its
    rows index true labels and its columns index predicted labels.

    Zero-denominator conventions:

    * precision is 0 for a class that is never predicted;
    * recall is 1 for a class that never occurs in ``y_true``;
    * f1 is 0 when precision + recall is 0.

    Args:
        y_pred (Tensor): predicted labels
        y_true (Tensor): true labels
        average (bool): if True, precision/recall/f1 are averaged (unweighted)
            over the class dimension; otherwise the per-class values are
            returned. Accuracy is unaffected by this flag.
    Returns:
        a tuple of the accuracy, precision, recall and f1 score
    """
    counts = confusion_matrix(y_pred, y_true)
    true_pos = counts.diagonal(dim1=-2, dim2=-1).float()
    zeros = torch.zeros_like(true_pos)

    # Column sums: how often each class was predicted.
    predicted_per_class = counts.sum(-2)
    # Row sums: how often each class actually occurs.
    actual_per_class = counts.sum(-1)

    # torch.where evaluates both branches, so the division may yield nan/inf
    # for empty classes; those entries are replaced by the fallback value.
    precision = torch.where(
        predicted_per_class == 0, zeros, true_pos / predicted_per_class
    )
    recall = torch.where(
        actual_per_class == 0, torch.ones_like(true_pos), true_pos / actual_per_class
    )

    denom = precision + recall
    f1 = torch.where(denom == 0, zeros, 2 * (precision * recall) / denom)

    accuracy = (y_pred == y_true).float().mean(-1)
    if average:
        precision, recall, f1 = (score.mean(-1) for score in (precision, recall, f1))
    return accuracy, precision, recall, f1
def confusion_matrix(
    y_pred: Tensor, y_true: Tensor, *, num_classes: Optional[int] = None
) -> Tensor:
    """Creates a confusion matrix given the predicted and true labels.

    Entry ``[..., i, j]`` of the result counts the positions along the last
    dimension where ``y_true == i`` and ``y_pred == j``. Rows are true labels
    and columns are predicted labels. Leading dimensions are treated as batch
    dimensions and kept in the output.

    Args:
        y_pred (Tensor): predicted labels (integer-valued), shape ``(..., N)``
        y_true (Tensor): true labels (integer-valued), broadcastable with
            ``y_pred``
        num_classes (int, optional): size of the label space. Defaults to
            ``max(y_pred.max(), y_true.max()) + 1``, or 0 for empty input.
    Returns:
        an ``int64`` tensor of shape ``(..., num_classes, num_classes)``.
        Pairs where either label lies outside ``[0, num_classes)`` (e.g. a
        negative ignore index) are not counted.
    """
    # .long() keeps bool / integral-float labels working as before.
    y_pred, y_true = torch.broadcast_tensors(y_pred.long(), y_true.long())

    if num_classes is None:
        # Infer the label space from the largest label seen. Clamp at 0 so
        # empty or all-negative input yields an empty matrix instead of raising.
        largest = (
            max(y_pred.max().item(), y_true.max().item())
            if y_true.numel() > 0
            else -1
        )
        num_classes = max(int(largest) + 1, 0)
    n = num_classes

    # Encode each (true, pred) pair as one bin index in [0, n * n). Pairs with
    # an out-of-range label go to a spare bin n * n, which is dropped below.
    # This needs O(N + n^2) memory per batch row rather than a materialised
    # (..., N, n, n) comparison tensor.
    valid = (y_true >= 0) & (y_true < n) & (y_pred >= 0) & (y_pred < n)
    codes = (y_true * n + y_pred).masked_fill(~valid, n * n)

    batch_shape = codes.shape[:-1]
    # Collapse all leading dims into one row dimension. Size.numel() stays
    # correct when N == 0, where reshape(-1, 0) would be ambiguous.
    codes = codes.reshape(batch_shape.numel(), codes.shape[-1])
    counts = torch.zeros(
        codes.shape[0], n * n + 1, dtype=torch.long, device=codes.device
    )
    counts.scatter_add_(1, codes, torch.ones_like(codes))
    return counts[:, :-1].reshape(*batch_shape, n, n)