|
import kornia as K
|
|
import torch
|
|
import torchmetrics.functional as F
|
|
from skimage.measure import label
|
|
from torchmetrics import Metric
|
|
|
|
|
|
class DNAFIBERMetric(Metric):
    """Blob-level detection and per-class segmentation scores for fiber masks.

    Predictions and targets are integer class maps in which ``0`` is treated
    as background.  Scores for class index ``1`` are accumulated under the
    ``fiber_red_*`` states and scores for class index ``2`` under the
    ``fiber_green_*`` states.

    For every sample the foreground (``> 0``) of the prediction and of the
    target is split into 8-connected components.  Each predicted component
    that shares at least one pixel with the target foreground counts as a
    detection true positive and is scored (dice / recall / precision) against
    the ground-truth component it overlaps the most; every other predicted
    component counts as a detection false positive.

    NOTE(review): ``detection_recall`` divides the number of matched
    *predicted* components by the number of ground-truth components, so it
    can exceed 1 when several predicted components hit the same ground-truth
    component.
    """

    # (state name, dtype) in registration order.
    _STATE_SPECS = (
        ("detection_tp", torch.int64),
        ("fiber_red_dice", torch.float32),
        ("fiber_green_dice", torch.float32),
        ("fiber_red_recall", torch.float32),
        ("fiber_green_recall", torch.float32),
        ("fiber_red_precision", torch.float32),
        ("fiber_green_precision", torch.float32),
        ("detection_fp", torch.int64),
        ("N", torch.int64),
    )

    def __init__(self, **kwargs):
        """Register all accumulators as zero scalars that are summed on sync.

        Args:
            **kwargs: Forwarded unchanged to ``Metric.__init__``.
        """
        super().__init__(**kwargs)
        for state_name, dtype in self._STATE_SPECS:
            self.add_state(
                state_name,
                default=torch.tensor(0, dtype=dtype),
                dist_reduce_fx="sum",
            )

    @staticmethod
    def _label_components(mask, device):
        """Label the 8-connected components of a boolean 2-D mask.

        Returns:
            A ``(label_map, count)`` pair: the label map as a tensor on
            ``device`` (0 = background, components numbered ``1..count``) and
            the number of components as a plain ``int``.
        """
        components, count = label(
            mask.detach().cpu().numpy(), connectivity=2, return_num=True
        )
        return torch.from_numpy(components).to(device), int(count)

    def _accumulate_fiber_scores(self, pred_fiber, gt_fiber):
        """Add per-class dice, recall and precision of one matched fiber.

        Both arguments are 1-D class-index tensors taken over the union of the
        predicted component and its matched ground-truth component.  NaN
        scores returned by torchmetrics are counted as 0.
        """
        dices = torch.nan_to_num(
            F.dice(
                pred_fiber,
                gt_fiber,
                num_classes=3,
                ignore_index=0,
                average=None,
            ),
            nan=0.0,
        )
        self.fiber_red_dice += dices[1]
        self.fiber_green_dice += dices[2]

        recalls = torch.nan_to_num(
            F.recall(
                pred_fiber,
                gt_fiber,
                num_classes=3,
                ignore_index=0,
                task="multiclass",
                average=None,
            ),
            nan=0.0,
        )
        self.fiber_red_recall += recalls[1]
        self.fiber_green_recall += recalls[2]

        precisions = torch.nan_to_num(
            F.precision(
                pred_fiber,
                gt_fiber,
                num_classes=3,
                ignore_index=0,
                task="multiclass",
                average=None,
            ),
            nan=0.0,
        )
        self.fiber_red_precision += precisions[1]
        self.fiber_green_precision += precisions[2]

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate detection counts and per-fiber scores for one batch.

        Args:
            preds: Either a 3-D class-index map or a 4-D tensor that is
                reduced with ``argmax`` over dim 1 (presumably
                ``(B, C, H, W)`` scores - confirm against the caller).
            target: Either a 3-D class-index map or a 4-D tensor whose dim 1
                is squeezed (presumably ``(B, 1, H, W)`` - confirm).

        Raises:
            ValueError: If ``preds`` is not 3-D after the optional argmax.
        """
        if preds.ndim == 4:
            preds = preds.argmax(dim=1)
        if target.ndim == 4:
            target = target.squeeze(1)

        # Unpacking doubles as a rank check; only the batch size is needed.
        batch_size, _, _ = preds.shape

        binary_preds = preds > 0
        binary_target = target > 0
        n_true_fibers = 0

        # Samples are handled one at a time, so an empty batch is a no-op.
        for i in range(batch_size):
            pred_components, n_pred_blobs = self._label_components(
                binary_preds[i], preds.device
            )
            gt_components, n_gt_fibers = self._label_components(
                binary_target[i], preds.device
            )
            n_true_fibers += n_gt_fibers

            # Component ids are consecutive integers starting at 1.
            for blob_id in range(1, n_pred_blobs + 1):
                blob_mask = pred_components == blob_id

                # Ground-truth component ids under the predicted blob,
                # background excluded.
                hit_labels = gt_components[blob_mask]
                hit_labels = hit_labels[hit_labels > 0]
                if hit_labels.numel() == 0:
                    self.detection_fp += 1
                    continue

                self.detection_tp += 1

                # Match against the ground-truth component sharing the most
                # pixels with the blob, so the result does not depend on how
                # the components happen to be numbered.
                gt_ids, overlap_sizes = torch.unique(
                    hit_labels, return_counts=True
                )
                gt_mask = gt_components == gt_ids[overlap_sizes.argmax()]

                # Score over the union so that both missed and spurious
                # pixels of the fiber are taken into account.
                region = blob_mask | gt_mask
                self._accumulate_fiber_scores(preds[i][region], target[i][region])

        self.N += n_true_fibers

    def compute(self) -> dict:
        """Return the aggregated metrics as a dict of scalar tensors.

        Detection precision is ``tp / (tp + fp)`` and detection recall is
        ``tp / N``.  Every ``fiber_*`` entry is the accumulated score divided
        by the number of detection true positives.  A ``1e-7`` term in each
        denominator avoids division by zero.
        """
        eps = 1e-7
        matched = self.detection_tp + eps
        return {
            "detection_precision": self.detection_tp
            / (self.detection_tp + self.detection_fp + eps),
            "detection_recall": self.detection_tp / (self.N + eps),
            "fiber_red_dice": self.fiber_red_dice / matched,
            "fiber_green_dice": self.fiber_green_dice / matched,
            "fiber_red_recall": self.fiber_red_recall / matched,
            "fiber_green_recall": self.fiber_green_recall / matched,
            "fiber_red_precision": self.fiber_red_precision / matched,
            "fiber_green_precision": self.fiber_green_precision / matched,
        }
|
|
|