import numpy as np
from skimage import measure
from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve


def calculate_max_f1(gt, scores):
    """Return the maximum F1 over all thresholds and the threshold that attains it."""
    precision, recall, thresholds = precision_recall_curve(gt, scores)
    a = 2 * precision * recall
    b = precision + recall
    f1s = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
    index = np.argmax(f1s)
    max_f1 = f1s[index]
    threshold = thresholds[index]
    return max_f1, threshold


def metric_cal(scores, gt_list, gt_mask_list, cal_pro=False):
    # calculate image-level ROC AUC score
    img_scores = scores.reshape(scores.shape[0], -1).max(axis=1)
    gt_list = np.asarray(gt_list, dtype=int)
    fpr, tpr, _ = roc_curve(gt_list, img_scores)
    img_roc_auc = roc_auc_score(gt_list, img_scores)
    # print('INFO: image ROCAUC: %.3f' % (img_roc_auc))

    img_f1, img_threshold = calculate_max_f1(gt_list, img_scores)

    gt_mask = np.asarray(gt_mask_list, dtype=int)
    pxl_f1, pxl_threshold = calculate_max_f1(gt_mask.flatten(), scores.flatten())

    # calculate per-pixel level ROCAUC
    fpr, tpr, _ = roc_curve(gt_mask.flatten(), scores.flatten())
    per_pixel_rocauc = roc_auc_score(gt_mask.flatten(), scores.flatten())

    # calculate max-F1 region score (expensive, so only when requested)
    if cal_pro:
        # pro_auc_score = cal_pro_metric(gt_mask_list, scores, fpr_thresh=0.3)
        max_f1_region = calculate_max_f1_region(gt_mask_list, scores)
    else:
        # pro_auc_score = 0
        max_f1_region = 0

    result_dict = {'i_roc': img_roc_auc * 100, 'p_roc': per_pixel_rocauc * 100,
                   'i_f1': img_f1 * 100, 'i_thresh': img_threshold,
                   'p_f1': pxl_f1 * 100, 'p_thresh': pxl_threshold,
                   'r_f1': max_f1_region * 100}
    return result_dict


def rescale(x):
    return (x - x.min()) / (x.max() - x.min())


def cal_pro_metric(labeled_imgs, score_imgs, fpr_thresh=0.3, max_steps=200):
    labeled_imgs = np.array(labeled_imgs)
    labeled_imgs[labeled_imgs <= 0.45] = 0
    labeled_imgs[labeled_imgs > 0.45] = 1
    labeled_imgs = labeled_imgs.astype(bool)

    max_th = score_imgs.max()
    min_th = score_imgs.min()
    delta = (max_th - min_th) / max_steps

    ious_mean = []
    ious_std = []
    pros_mean = []
    pros_std = []
    threds = []
    fprs = []
    binary_score_maps = np.zeros_like(score_imgs, dtype=bool)
    for step in range(max_steps):
        thred = max_th - step * delta
        # segmentation
        binary_score_maps[score_imgs <= thred] = 0
        binary_score_maps[score_imgs > thred] = 1

        pro = []  # per-region overlap
        iou = []  # per-image IoU
        # pro: for each connected GT region, compute the fraction of its pixels covered by the prediction
        # iou: for each image, compute intersection/union between the GT and predicted binary maps
        for i in range(len(binary_score_maps)):  # for the i-th image
            # pro (per-region level)
            label_map = measure.label(labeled_imgs[i], connectivity=2)
            props = measure.regionprops(label_map)
            for prop in props:
                x_min, y_min, x_max, y_max = prop.bbox
                cropped_pred_label = binary_score_maps[i][x_min:x_max, y_min:y_max]
                # cropped_mask = masks[i][x_min:x_max, y_min:y_max]
                cropped_mask = prop.filled_image  # corrected!
                intersection = np.logical_and(cropped_pred_label, cropped_mask).astype(np.float32).sum()
                pro.append(intersection / prop.area)
            # iou (per-image level)
            intersection = np.logical_and(binary_score_maps[i], labeled_imgs[i]).astype(np.float32).sum()
            union = np.logical_or(binary_score_maps[i], labeled_imgs[i]).astype(np.float32).sum()
            if labeled_imgs[i].any():  # skip images whose GT has no anomaly pixels
                iou.append(intersection / union)
        # average the metrics over the test set at this threshold step
        ious_mean.append(np.array(iou).mean())
        # print("per image mean iou:", np.array(iou).mean())
        ious_std.append(np.array(iou).std())
        pros_mean.append(np.array(pro).mean())
        pros_std.append(np.array(pro).std())
        # fpr for PRO-AUC
        masks_neg = ~labeled_imgs
        fpr = np.logical_and(masks_neg, binary_score_maps).sum() / masks_neg.sum()
        fprs.append(fpr)
        threds.append(thred)

    # as array
    threds = np.array(threds)
    pros_mean = np.array(pros_mean)
    pros_std = np.array(pros_std)
    fprs = np.array(fprs)

    # default: keep thresholds with at most 30% FPR, then compute PRO-AUC
    idx = fprs <= fpr_thresh  # indices of fprs less than fpr_thresh (default 0.3)
    fprs_selected = fprs[idx]
    fprs_selected = rescale(fprs_selected)  # rescale fpr [0, 0.3] -> [0, 1]
    pros_mean_selected = pros_mean[idx]
    pro_auc_score = auc(fprs_selected, pros_mean_selected)
    # print("pro auc ({}% FPR):".format(int(fpr_thresh * 100)), pro_auc_score)
    return pro_auc_score


def calculate_max_f1_region(labeled_imgs, score_imgs, pro_thresh=0.6, max_steps=200):
    labeled_imgs = np.array(labeled_imgs)
    # labeled_imgs[labeled_imgs <= 0.1] = 0
    # labeled_imgs[labeled_imgs > 0.1] = 1
    labeled_imgs = labeled_imgs.astype(bool)

    max_th = score_imgs.max()
    min_th = score_imgs.min()
    delta = (max_th - min_th) / max_steps

    f1_list = []
    recall_list = []
    precision_list = []
    binary_score_maps = np.zeros_like(score_imgs, dtype=bool)
    for step in range(max_steps):
        thred = max_th - step * delta
        # segmentation
        binary_score_maps[score_imgs <= thred] = 0
        binary_score_maps[score_imgs > thred] = 1

        pro = []  # per-region overlap
        predict_region_number = 0
        gt_region_number = 0
        # for each predicted connected region, find the best-matching GT region and record their IoU
        for i in range(len(binary_score_maps)):  # for the i-th image
            # per-region level
            label_map = measure.label(labeled_imgs[i], connectivity=2)
            props = measure.regionprops(label_map)
            score_map = measure.label(binary_score_maps[i], connectivity=2)
            score_props = measure.regionprops(score_map)
            predict_region_number += len(score_props)
            gt_region_number += len(props)
            # if len(score_props) == 0 or len(props) == 0:
            #     pro.append(0)
            #     continue
            for score_prop in score_props:
                x_min_0, y_min_0, x_max_0, y_max_0 = score_prop.bbox
                cur_pros = [0]
                for prop in props:
                    x_min_1, y_min_1, x_max_1, y_max_1 = prop.bbox
                    # union of the two bounding boxes
                    x_min = min(x_min_0, x_min_1)
                    y_min = min(y_min_0, y_min_1)
                    x_max = max(x_max_0, x_max_1)
                    y_max = max(y_max_0, y_max_1)
                    cropped_pred_label = binary_score_maps[i][x_min:x_max, y_min:y_max]
                    cropped_gt_label = labeled_imgs[i][x_min:x_max, y_min:y_max]
                    # cropped_mask = masks[i][x_min:x_max, y_min:y_max]
                    # cropped_mask = prop.filled_image  # corrected!
                    intersection = np.logical_and(cropped_pred_label, cropped_gt_label).astype(np.float32).sum()
                    union = np.logical_or(cropped_pred_label, cropped_gt_label).astype(np.float32).sum()
                    cur_pros.append(intersection / union)
                pro.append(max(cur_pros))
        pro = np.array(pro)

        if gt_region_number == 0 or predict_region_number == 0:
            print(f'gt_number: {gt_region_number}, pred_number: {predict_region_number}')
            recall = 0
            precision = 0
            f1 = 0
        else:
            # a predicted region counts as a hit if its best IoU with any GT region reaches pro_thresh
            recall = (pro >= pro_thresh).astype(np.float32).sum() / gt_region_number
            precision = (pro >= pro_thresh).astype(np.float32).sum() / predict_region_number
            if recall == 0 or precision == 0:
                f1 = 0
            else:
                f1 = 2 * recall * precision / (recall + precision)
        f1_list.append(f1)
        recall_list.append(recall)
        precision_list.append(precision)

    # as array
    f1_list = np.array(f1_list)
    max_f1 = f1_list.max()
    cor_recall = recall_list[f1_list.argmax()]
    cor_precision = precision_list[f1_list.argmax()]
    print(f'cor recall: {cor_recall}, cor precision: {cor_precision}')
    return max_f1
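

# The block below is an illustrative smoke test, not part of the original evaluation
# code. It assumes `scores` has shape (N, H, W) and that the ground truth comes as a
# list of image-level labels plus a list of (H, W) binary masks, which matches how
# `metric_cal` indexes its inputs above; the shapes and values are toy assumptions.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, h, w = 8, 32, 32
    scores = rng.random((n, h, w)).astype(np.float32)
    gt_list = [0, 0, 0, 0, 1, 1, 1, 1]  # assumed toy split: half normal, half anomalous
    gt_mask_list = []
    for label in gt_list:
        mask = np.zeros((h, w), dtype=int)
        if label:
            mask[8:16, 8:16] = 1  # square "defect" region for anomalous images
        gt_mask_list.append(mask)
    # Boost scores inside the defect regions so the toy metrics are non-trivial.
    for i, mask in enumerate(gt_mask_list):
        scores[i][mask == 1] += 0.5
    print(metric_cal(scores, gt_list, gt_mask_list, cal_pro=True))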