import numpy as np
import torch
import monai.metrics as metrics

from common.constants import DIST_MATRIX_PATH

# Distance matrix loaded once at import time from DIST_MATRIX_PATH.
DIST_MATRIX = np.load(DIST_MATRIX_PATH)


def metric(mask, pred, back=True):
    """Return the mean IoU of ``pred`` against ``mask``.

    Thin wrapper around ``monai.metrics.compute_meaniou(pred, mask, back, False)``.
    Per the MONAI signature ``(y_pred, y, include_background, ignore_empty)``,
    ``back`` toggles inclusion of the background channel and empty ground
    truth is not ignored. The result returned by MONAI is averaged with
    ``.mean()`` into a single value.
    """
    iou = metrics.compute_meaniou(pred, mask, back, False)
    iou = iou.mean()
    return iou


def precision_recall_f1score(gt, pred):
    """Compute per-sample precision, recall and F1.

    A prediction counts as positive when ``pred >= 0.5``. Indexing as
    ``x[b, :]`` assumes ``gt`` and ``pred`` are indexed (batch, items) with
    matching shapes — TODO confirm against callers.

    Edge cases:
      * no positive prediction  -> precision=1, recall=0, f1=0
      * no positive ground truth -> precision=0, recall=1, f1=0
      * precision + recall ~ 0   -> all three set to 0

    Returns:
        ``(precision, recall, f1)``, each a float tensor of length
        ``gt.shape[0]``.
    """
    batch_size = gt.shape[0]
    precision = torch.zeros(batch_size)
    recall = torch.zeros(batch_size)
    f1 = torch.zeros(batch_size)
    for b in range(batch_size):
        pred_pos = pred[b, :] >= 0.5
        tp_num = gt[b, pred_pos].sum()
        precision_denominator = pred_pos.sum()
        recall_denominator = gt[b, :].sum()
        # Check the denominators BEFORE dividing so no inf/nan is ever produced.
        if precision_denominator == 0:  # no prediction
            precision_, recall_, f1_ = 1.0, 0.0, 0.0
        elif recall_denominator == 0:  # no GT
            precision_, recall_, f1_ = 0.0, 1.0, 0.0
        else:
            precision_ = tp_num / precision_denominator
            recall_ = tp_num / recall_denominator
            if (precision_ + recall_) <= 1e-10:  # avoid precision issues in F1
                precision_, recall_, f1_ = 0.0, 0.0, 0.0
            else:
                f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
        precision[b] = precision_
        recall[b] = recall_
        f1[b] = f1_
    return precision, recall, f1


def acc_precision_recall_f1score(gt, pred):
    """Compute per-sample accuracy, precision, recall and F1.

    A prediction counts as positive when ``pred >= 0.5``. The number of items
    per sample is taken from ``gt.shape[-1]``, which assumes ``gt``/``pred``
    are (batch, items) — TODO confirm against callers. Division-by-zero is
    avoided with a ``1e-10`` epsilon rather than the explicit edge-case rules
    used by ``precision_recall_f1score``.

    Returns:
        ``(acc, precision, recall, f1)``, each a float tensor of length
        ``gt.shape[0]``.
    """
    batch_size = gt.shape[0]
    num_items = gt.shape[-1]
    acc = torch.zeros(batch_size)
    precision = torch.zeros(batch_size)
    recall = torch.zeros(batch_size)
    f1 = torch.zeros(batch_size)
    for b in range(batch_size):
        pred_pos = pred[b, :] >= 0.5
        tp_num = gt[b, pred_pos].sum()
        precision_denominator = pred_pos.sum()  # TP + FP
        recall_denominator = gt[b, :].sum()  # TP + FN
        # TN = N - (TP + FP) - (TP + FN) + TP
        tn_num = num_items - precision_denominator - recall_denominator + tp_num
        acc_ = (tp_num + tn_num) / num_items
        precision_ = tp_num / (precision_denominator + 1e-10)
        recall_ = tp_num / (recall_denominator + 1e-10)
        f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + 1e-10)
        acc[b] = acc_
        precision[b] = precision_
        recall[b] = recall_
        f1[b] = f1_  # was missing: f1 used to be returned as all zeros
    return acc, precision, recall, f1
def det_error_metric(pred, gt, dist_matrix=None):
    """Compute per-sample false-positive and false-negative distance errors.

    For sample ``b`` an error matrix is built from ``dist_matrix``: rows are
    selected by the predicted positives (``pred[b] >= 0.5``) and columns by
    the ground-truth positives (``gt[b] == 1``). When a sample has no
    ground-truth positive, all columns are kept; when it has no predicted
    positive, all rows are kept.

    * false-positive distance: mean over selected rows of the row minimum
      (how far each prediction is from its nearest GT item).
    * false-negative distance: mean over selected columns of the column
      minimum (how far each GT item is from its nearest prediction).

    Args:
        pred: prediction scores, indexed as ``pred[b, :]`` (presumably
            (batch, items) — TODO confirm).
        gt: binary ground truth with the same indexing as ``pred``.
        dist_matrix: optional distance matrix (array or tensor) whose rows
            index predicted items and columns index GT items. Defaults to the
            module-level ``DIST_MATRIX``.

    Returns:
        ``(false_positive_dist, false_negative_dist)``, each a float tensor
        of length ``gt.shape[0]``.
    """
    if dist_matrix is None:
        dist_matrix = DIST_MATRIX
    # as_tensor avoids copying the (read-only use) matrix on every call.
    dist_matrix = torch.as_tensor(dist_matrix)
    gt = gt.detach().cpu()
    pred = pred.detach().cpu()
    batch_size = gt.shape[0]
    false_positive_dist = torch.zeros(batch_size)
    false_negative_dist = torch.zeros(batch_size)
    for b in range(batch_size):
        gt_mask = gt[b, :] == 1
        pred_mask = pred[b, :] >= 0.5
        gt_columns = dist_matrix[:, gt_mask] if gt_mask.any() else dist_matrix
        error_matrix = gt_columns[pred_mask, :] if pred_mask.any() else gt_columns
        false_positive_dist[b] = error_matrix.min(dim=1)[0].mean()
        false_negative_dist[b] = error_matrix.min(dim=0)[0].mean()
    return false_positive_dist, false_negative_dist