from copy import deepcopy
from multiprocessing.pool import Pool

from batchgenerators.utilities.file_and_folder_operations import *
from medpy import metric
import SimpleITK as sitk
import numpy as np

from nnunet.configuration import default_num_threads
from nnunet.postprocessing.consolidate_postprocessing import collect_cv_niftis


def get_brats_regions():
    """
    Return the label groups that make up the BraTS evaluation regions.

    This is only valid for the BraTS data as used here, where the labels are 1, 2 and 3.
    The original BraTS data have a different labeling convention!

    :return: dict mapping region name -> tuple of labels that are merged into that region
    """
    regions = {
        "whole tumor": (1, 2, 3),
        "tumor core": (2, 3),
        "enhancing tumor": (3,)
    }
    return regions


def get_KiTS_regions():
    """
    Return the label groups that make up the KiTS evaluation regions.

    :return: dict mapping region name -> tuple of labels that are merged into that region
    """
    regions = {
        "kidney incl tumor": (1, 2),
        "tumor": (2,)
    }
    return regions


def create_region_from_mask(mask, join_labels: tuple):
    """
    Build a binary mask that is 1 wherever ``mask`` equals any of ``join_labels``.

    :param mask: label array
    :param join_labels: labels that are merged into the region
    :return: uint8 array with the same shape as ``mask`` (1 inside the region, 0 elsewhere)
    """
    # list() so that any iterable of labels (not only tuples) is handled element-wise by np.isin
    return np.isin(mask, list(join_labels)).astype(np.uint8)


def evaluate_case(file_pred: str, file_gt: str, regions):
    """
    Compute the Dice coefficient (medpy ``metric.dc``) of one case for every region.

    :param file_pred: path of the predicted segmentation
    :param file_gt: path of the ground truth segmentation
    :param regions: iterable of label tuples, one per region
    :return: list with one entry per region (same order as ``regions``). An entry is NaN if the
        region is empty in both the prediction and the ground truth.
    """
    image_gt = sitk.GetArrayFromImage(sitk.ReadImage(file_gt))
    image_pred = sitk.GetArrayFromImage(sitk.ReadImage(file_pred))
    results = []
    for r in regions:
        mask_pred = create_region_from_mask(image_pred, r)
        mask_gt = create_region_from_mask(image_gt, r)
        if not mask_gt.any() and not mask_pred.any():
            # region absent in both images: Dice is undefined, report NaN
            dc = np.nan
        else:
            dc = metric.dc(mask_pred, mask_gt)
        results.append(dc)
    return results


def _format_summary_row(row_name, values):
    """Format one line of summary.csv: the row name followed by the values with 4 decimals."""
    return row_name + "".join(",%02.4f" % v for v in values) + "\n"


def evaluate_regions(folder_predicted: str, folder_gt: str, regions: dict, processes=default_num_threads):
    """
    Evaluate all '.nii.gz' predictions in ``folder_predicted`` against ``folder_gt`` and write the
    per-case and aggregated Dice scores to 'summary.csv' in ``folder_predicted``.

    The csv contains one row per case followed by the rows 'mean', 'median' (both ignoring NaN) and
    'mean (nan is 1)', 'median (nan is 1)' (NaN counted as a Dice of 1).

    :param folder_predicted: folder with the predicted segmentations; summary.csv is written here
    :param folder_gt: folder with the ground truth segmentations (matched by file name)
    :param regions: dict mapping region name -> tuple of labels that form the region
    :param processes: number of worker processes used for the evaluation
    :raises AssertionError: if a predicted file has no ground truth file of the same name
    """
    region_names = list(regions.keys())
    region_labels = list(regions.values())

    files_in_pred = subfiles(folder_predicted, suffix='.nii.gz', join=False)
    files_in_gt = subfiles(folder_gt, suffix='.nii.gz', join=False)

    # sets make the two membership checks linear instead of quadratic
    gt_names = set(files_in_gt)
    pred_names = set(files_in_pred)

    have_no_gt = [i for i in files_in_pred if i not in gt_names]
    if have_no_gt:
        # raised explicitly (same exception type as the former assert) so the check survives python -O
        raise AssertionError("Some files in folder_predicted have not ground truth in folder_gt")
    have_no_pred = [i for i in files_in_gt if i not in pred_names]
    if have_no_pred:
        print("WARNING! Some files in folder_gt were not predicted (not present in folder_predicted)!")

    files_in_pred.sort()

    # run for all cases. Only predicted cases are evaluated, so both lists are built from files_in_pred
    full_filenames_gt = [join(folder_gt, i) for i in files_in_pred]
    full_filenames_pred = [join(folder_predicted, i) for i in files_in_pred]

    p = Pool(processes)
    try:
        res = p.starmap(evaluate_case, zip(full_filenames_pred, full_filenames_gt,
                                           [region_labels] * len(files_in_pred)))
    finally:
        # always release the worker processes, also if a case failed
        p.close()
        p.join()

    # region name -> list of Dice scores, one per case (same order as files_in_pred)
    all_results = {r: [case_result[k] for case_result in res] for k, r in enumerate(region_names)}

    # same scores with NaN (region absent in prediction and ground truth) counted as a perfect score
    all_results_nan_is_1 = {}
    for r in region_names:
        tmp = np.array(all_results[r])
        tmp[np.isnan(tmp)] = 1
        all_results_nan_is_1[r] = tmp

    with open(join(folder_predicted, 'summary.csv'), 'w') as f:
        f.write("casename" + "".join(",%s" % r for r in region_names) + "\n")

        for filename, case_result in zip(files_in_pred, res):
            # [:-7] strips the '.nii.gz' suffix
            f.write(_format_summary_row(filename[:-7], [case_result[k] for k in range(len(region_names))]))

        f.write(_format_summary_row('mean', [np.nanmean(all_results[r]) for r in region_names]))
        f.write(_format_summary_row('median', [np.nanmedian(all_results[r]) for r in region_names]))
        f.write(_format_summary_row('mean (nan is 1)',
                                    [np.mean(all_results_nan_is_1[r]) for r in region_names]))
        f.write(_format_summary_row('median (nan is 1)',
                                    [np.median(all_results_nan_is_1[r]) for r in region_names]))


if __name__ == '__main__':
    collect_cv_niftis('./', './cv_niftis')
    evaluate_regions('./cv_niftis/', './gt_niftis/', get_brats_regions())