# DNAI / dnafiber / metric.py
# Uploaded by ClementP ("Upload 55 files"), commit 69591a9 (verified).
import kornia as K
import torch
import torchmetrics.functional as F
from skimage.measure import label
from torchmetrics import Metric
class DNAFIBERMetric(Metric):
    """Object-level detection and per-fiber segmentation metrics for DNA fibers.

    Predicted and ground-truth foreground (class > 0) are split into connected
    components with ``skimage.measure.label(..., connectivity=2)``. A predicted
    component that overlaps any ground-truth foreground pixel is a detection
    true positive; otherwise it is a detection false positive. Each true
    positive is matched to the ground-truth component it shares the most
    pixels with, and per-class dice / recall / precision are computed over the
    union of the two components. Class 1 is reported as "red", class 2 as
    "green", and class 0 is excluded via ``ignore_index=0``.

    ``compute`` returns detection precision and recall, plus the per-fiber
    scores averaged over detection true positives.

    NOTE(review): every predicted component touching the same ground-truth
    fiber counts as a separate true positive, so ``detection_recall`` can
    exceed 1 when a fiber is predicted in several pieces.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # (state name, dtype), in the original registration order. Every state
        # is a running sum reduced with "sum" across processes; ``compute``
        # turns the sums into ratios.
        states = (
            ("detection_tp", torch.int64),
            ("fiber_red_dice", torch.float32),
            ("fiber_green_dice", torch.float32),
            ("fiber_red_recall", torch.float32),
            ("fiber_green_recall", torch.float32),
            ("fiber_red_precision", torch.float32),
            ("fiber_green_precision", torch.float32),
            ("detection_fp", torch.int64),
            # Number of ground-truth fiber components seen so far.
            ("N", torch.int64),
        )
        for name, dtype in states:
            self.add_state(
                name,
                default=torch.tensor(0, dtype=dtype),
                dist_reduce_fx="sum",
            )

    @staticmethod
    def _fiber_components(mask, device):
        """Label connected foreground components of a 2-D boolean mask.

        Returns the label map (0 = background, 1..K = components) as a tensor
        on ``device``, together with the component count K.
        """
        labels_np, count = label(
            mask.detach().cpu().numpy(), connectivity=2, return_num=True
        )
        return torch.from_numpy(labels_np).to(device), int(count)

    def _add_fiber_scores(self, pred_map, target_map, region):
        """Add red/green dice, recall and precision over ``region`` to the sums.

        ``pred_map`` / ``target_map`` are the class maps of one image and
        ``region`` is a boolean mask selecting the pixels to score.
        """
        pred_fiber = pred_map[region]
        gt_fiber = target_map[region]

        dices = F.dice(
            pred_fiber,
            gt_fiber,
            num_classes=3,
            ignore_index=0,
            average=None,
        )
        recalls = F.recall(
            pred_fiber,
            gt_fiber,
            num_classes=3,
            ignore_index=0,
            task="multiclass",
            average=None,
        )
        precisions = F.precision(
            pred_fiber,
            gt_fiber,
            num_classes=3,
            ignore_index=0,
            task="multiclass",
            average=None,
        )
        # A class absent from both prediction and target yields NaN; it is
        # counted as 0 for this fiber.
        dices = torch.nan_to_num(dices, nan=0.0)
        recalls = torch.nan_to_num(recalls, nan=0.0)
        precisions = torch.nan_to_num(precisions, nan=0.0)

        self.fiber_red_dice += dices[1]
        self.fiber_green_dice += dices[2]
        self.fiber_red_recall += recalls[1]
        self.fiber_green_recall += recalls[2]
        self.fiber_red_precision += precisions[1]
        self.fiber_green_precision += precisions[2]

    def update(self, preds, target):
        """Accumulate detection counts and per-fiber scores for one batch.

        Args:
            preds: Either per-class scores of shape (B, C, H, W), reduced with
                ``argmax`` over dim 1, or a class map of shape (B, H, W).
            target: Class map of shape (B, 1, H, W) or (B, H, W).
                NOTE(review): assumed to live on the same device as ``preds``.
        """
        if preds.ndim == 4:
            preds = preds.argmax(dim=1)
        if target.ndim == 4:
            target = target.squeeze(1)

        # Images are processed one at a time (no torch.stack), so an empty
        # batch simply contributes nothing instead of raising.
        for i in range(preds.shape[0]):
            pred_components, n_pred = self._fiber_components(
                preds[i] > 0, preds.device
            )
            target_components, n_target = self._fiber_components(
                target[i] > 0, preds.device
            )
            self.N += n_target

            # skimage labels components consecutively as 1..n_pred.
            for blob_id in range(1, n_pred + 1):
                pred_mask = pred_components == blob_id
                # Ground-truth component ids under this blob, background dropped.
                overlap = target_components[pred_mask]
                overlap = overlap[overlap > 0]
                if overlap.numel() == 0:
                    self.detection_fp += 1
                    continue

                self.detection_tp += 1
                # Match to the ground-truth fiber sharing the most pixels with
                # the blob, rather than whichever label id happens to be largest.
                gt_ids, counts = overlap.unique(return_counts=True)
                gt_mask = target_components == gt_ids[counts.argmax()]
                self._add_fiber_scores(preds[i], target[i], pred_mask | gt_mask)

    def compute(self):
        """Return detection precision/recall and per-fiber averages.

        Per-fiber scores are averaged over detection true positives. A 1e-7
        epsilon in every denominator avoids division by zero; when nothing was
        detected, every returned score is 0.
        """
        eps = 1e-7
        matched = self.detection_tp + eps
        return {
            "detection_precision": self.detection_tp
            / (self.detection_tp + self.detection_fp + eps),
            "detection_recall": self.detection_tp / (self.N + eps),
            "fiber_red_dice": self.fiber_red_dice / matched,
            "fiber_green_dice": self.fiber_green_dice / matched,
            "fiber_red_recall": self.fiber_red_recall / matched,
            "fiber_green_recall": self.fiber_green_recall / matched,
            "fiber_red_precision": self.fiber_red_precision / matched,
            "fiber_green_precision": self.fiber_green_precision / matched,
        }