Spaces:
Sleeping
Sleeping
File size: 3,182 Bytes
99a05f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import numpy as np
import torch
import monai.metrics as metrics
from common.constants import DIST_MATRIX_PATH
# Distance matrix loaded once at import time. det_error_metric indexes its
# rows by predicted positives and its columns by ground-truth positives.
DIST_MATRIX = np.load(file=DIST_MATRIX_PATH)
def metric(mask, pred, back=True):
    """
    Return the mean IoU of ``pred`` against ``mask``, computed with MONAI.

    Args:
        mask: ground-truth mask, passed to MONAI as the second (target) argument.
        pred: prediction, passed to MONAI as the first (prediction) argument.
        back: forwarded as MONAI's third positional argument (presumably
            ``include_background`` per MONAI's signature — confirm against the
            installed version). The fourth argument is fixed to ``False``.

    Returns:
        The mean over all entries of MONAI's IoU result.
    """
    per_item_iou = metrics.compute_meaniou(pred, mask, back, False)
    return per_item_iou.mean()
def precision_recall_f1score(gt, pred):
    """
    Compute per-sample precision, recall and F1 score.

    Entries of ``pred`` that are >= 0.5 count as predicted positives. The true
    positive count is the sum of ``gt`` over those entries (assumes ``gt`` is
    binary {0, 1} — TODO confirm with callers).

    Edge-case conventions, evaluated in this order:
      * no predicted positives -> precision 1, recall 0, F1 0. This also
        applies when ``gt`` has no positives either.
      * no ground-truth positives -> precision 0, recall 1, F1 0
      * no true positives        -> precision 0, recall 0, F1 0

    Args:
        gt: ground-truth labels indexed as ``gt[b, c]`` (sample b, entry c).
        pred: prediction scores with the same layout as ``gt``.

    Returns:
        Tuple ``(precision, recall, f1)``. Each is a float tensor of length
        ``gt.shape[0]``.
    """
    batch_size = gt.shape[0]
    precision = torch.zeros(batch_size)
    recall = torch.zeros(batch_size)
    f1 = torch.zeros(batch_size)
    for b in range(batch_size):
        pred_pos = pred[b, :] >= 0.5
        tp_num = gt[b, pred_pos].sum()
        precision_denominator = pred_pos.sum()
        recall_denominator = gt[b, :].sum()
        if precision_denominator == 0:  # no positive prediction
            precision_, recall_, f1_ = 1., 0., 0.
        elif recall_denominator == 0:  # no positive ground truth
            precision_, recall_, f1_ = 0., 1., 0.
        else:
            # Both denominators are non-zero here. Dividing only now avoids
            # producing nan/inf intermediates, which also avoids RuntimeWarnings
            # for NumPy inputs.
            precision_ = tp_num / precision_denominator
            recall_ = tp_num / recall_denominator
            if (precision_ + recall_) <= 1e-10:  # no true positives
                precision_, recall_, f1_ = 0., 0., 0.
            else:
                f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
        precision[b] = precision_
        recall[b] = recall_
        f1[b] = f1_
    return precision, recall, f1
def acc_precision_recall_f1score(gt, pred):
    """
    Compute per-sample accuracy, precision, recall and F1 score.

    Entries of ``pred`` that are >= 0.5 count as predicted positives. Empty
    denominators are guarded with a 1e-10 epsilon rather than special-cased.
    As a result, a sample with no predicted (or no ground-truth) positives gets
    a precision (or recall) of 0.

    Args:
        gt: ground-truth labels indexed as ``gt[b, c]``. Presumably binary
            {0, 1}; the true-negative count below is only correct for binary
            labels.
        pred: prediction scores with the same layout as ``gt``.

    Returns:
        Tuple ``(acc, precision, recall, f1)``. Each is a float tensor of length
        ``gt.shape[0]``.
    """
    batch_size = gt.shape[0]
    num_entries = gt.shape[-1]
    acc = torch.zeros(batch_size)
    precision = torch.zeros(batch_size)
    recall = torch.zeros(batch_size)
    f1 = torch.zeros(batch_size)
    for b in range(batch_size):
        pred_pos = pred[b, :] >= 0.5
        tp_num = gt[b, pred_pos].sum()
        precision_denominator = pred_pos.sum()
        recall_denominator = gt[b, :].sum()
        # Inclusion-exclusion: TN = N - |pred| - |gt| + TP.
        tn_num = num_entries - precision_denominator - recall_denominator + tp_num
        acc_ = (tp_num + tn_num) / num_entries
        precision_ = tp_num / (precision_denominator + 1e-10)
        recall_ = tp_num / (recall_denominator + 1e-10)
        f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + 1e-10)
        acc[b] = acc_
        precision[b] = precision_
        recall[b] = recall_
        f1[b] = f1_  # previously missing, so f1 was always returned as zeros
    return acc, precision, recall, f1
def det_error_metric(pred, gt, *, dist_matrix=None):
    """
    Compute per-sample false-positive and false-negative distance errors.

    For sample ``b``, the columns of the distance matrix picked by the
    ground-truth positives (``gt[b] == 1``) and the rows picked by the
    predicted positives (``pred[b] >= 0.5``) form an error sub-matrix.
    If a sample has no ground-truth positives, all columns are kept; if it has
    no predicted positives, all rows are kept.

    * false-positive distance: mean over selected rows of the row minimum,
      i.e. how far each prediction is from its nearest ground-truth entry.
    * false-negative distance: mean over selected columns of the column
      minimum, i.e. how far each ground-truth entry is from its nearest
      prediction.

    Args:
        pred: prediction scores indexed as ``pred[b, i]`` (moved to CPU).
        gt: ground-truth labels indexed as ``gt[b, j]`` (moved to CPU).
        dist_matrix: optional keyword-only distance matrix (array-like or
            tensor) whose rows/columns align with the entries of
            ``pred``/``gt``. Defaults to the module-level ``DIST_MATRIX``.

    Returns:
        Tuple ``(false_positive_dist, false_negative_dist)``. Each is a float
        tensor of length ``gt.shape[0]``.
    """
    gt = gt.detach().cpu()
    pred = pred.detach().cpu()
    if dist_matrix is None:
        dist_matrix = DIST_MATRIX
    # as_tensor avoids copying the NumPy matrix on every call. No in-place
    # operations are applied to it below.
    dist_matrix = torch.as_tensor(dist_matrix).cpu()
    false_positive_dist = torch.zeros(gt.shape[0])
    false_negative_dist = torch.zeros(gt.shape[0])
    for b in range(gt.shape[0]):
        gt_pos = gt[b, :] == 1
        pred_pos = pred[b, :] >= 0.5
        # Tensor.any() replaces the builtin any(), which looped in Python
        # over every element.
        gt_columns = dist_matrix[:, gt_pos] if gt_pos.any() else dist_matrix
        error_matrix = gt_columns[pred_pos, :] if pred_pos.any() else gt_columns
        false_positive_dist[b] = error_matrix.min(dim=1)[0].mean()
        false_negative_dist[b] = error_matrix.min(dim=0)[0].mean()
    return false_positive_dist, false_negative_dist