|
from copy import deepcopy |
|
from multiprocessing.pool import Pool |
|
|
|
from batchgenerators.utilities.file_and_folder_operations import * |
|
from medpy import metric |
|
import SimpleITK as sitk |
|
import numpy as np |
|
from nnunet.configuration import default_num_threads |
|
from nnunet.postprocessing.consolidate_postprocessing import collect_cv_niftis |
|
|
|
|
|
def get_brats_regions():
    """
    Return the BraTS evaluation regions, mapping each region name to the labels merged into it.

    Only valid for the BraTS data used here, where the labels are 1, 2 and 3. The original BraTS data use a
    different labeling convention!

    :return: dict region name -> tuple of label values, in the order whole tumor, tumor core, enhancing tumor
    """
    return {
        "whole tumor": (1, 2, 3),
        "tumor core": (2, 3),
        "enhancing tumor": (3,),
    }
|
|
|
|
|
def get_KiTS_regions():
    """
    Return the KiTS evaluation regions, mapping each region name to the labels merged into it.

    :return: dict region name -> tuple of label values, in the order kidney incl tumor, tumor
    """
    return {
        "kidney incl tumor": (1, 2),
        "tumor": (2,),
    }
|
|
|
|
|
def create_region_from_mask(mask, join_labels: tuple):
    """
    Collapse a label map into a binary mask of a single region.

    :param mask: array of label values
    :param join_labels: the label values that together make up the region
    :return: uint8 array with the same shape as `mask`; 1 where `mask` equals any of `join_labels`, 0 elsewhere
    """
    in_region = np.isin(mask, join_labels)
    return in_region.astype(np.uint8)
|
|
|
|
|
def evaluate_case(file_pred: str, file_gt: str, regions):
    """
    Compute the Dice coefficient of one predicted segmentation against its reference, for each region.

    :param file_pred: path of the predicted segmentation image
    :param file_gt: path of the reference segmentation image
    :param regions: iterable of label tuples; each tuple is merged into one binary region
    :return: list with one Dice value per region, in the order of `regions`; NaN when the region is empty in
             both the prediction and the reference
    """
    def _load_labels(path):
        return sitk.GetArrayFromImage(sitk.ReadImage(path))

    gt_labels = _load_labels(file_gt)
    pred_labels = _load_labels(file_pred)

    dice_per_region = []
    for labels in regions:
        pred_region = create_region_from_mask(pred_labels, labels)
        gt_region = create_region_from_mask(gt_labels, labels)
        # Dice is undefined when the region is absent from both images; report NaN instead of a score
        both_empty = not gt_region.any() and not pred_region.any()
        dice_per_region.append(np.nan if both_empty else metric.dc(pred_region, gt_region))
    return dice_per_region
|
|
|
|
|
def _nan_as_one(values):
    """Return `values` as a float array in which every NaN has been replaced by 1."""
    arr = np.array(values, dtype=float)
    arr[np.isnan(arr)] = 1
    return arr


def evaluate_regions(folder_predicted: str, folder_gt: str, regions: dict, processes=default_num_threads):
    """
    Evaluate every prediction in `folder_predicted` region-wise and write `summary.csv` into that folder.

    Every .nii.gz file in `folder_predicted` must have a file with the same name in `folder_gt`. Ground truth
    files without a prediction only trigger a printed warning and are otherwise ignored.

    summary.csv layout:
    - a header row `casename,<region names...>`;
    - one row per case (file name without `.nii.gz`) with the per-region value returned by `evaluate_case`;
    - four aggregate rows: mean and median ignoring NaN, then mean and median with NaN counted as 1.

    :param folder_predicted: folder containing the predicted segmentations; summary.csv is written here
    :param folder_gt: folder containing the reference segmentations
    :param regions: dict mapping region name -> tuple of labels merged into that region
    :param processes: number of worker processes used to evaluate the cases
    :raises AssertionError: if a prediction has no matching file in `folder_gt`
    """
    region_names = list(regions.keys())
    region_labels = list(regions.values())

    files_in_pred = subfiles(folder_predicted, suffix='.nii.gz', join=False)
    files_in_gt = subfiles(folder_gt, suffix='.nii.gz', join=False)

    # sets keep the cross-membership checks linear in the number of cases instead of quadratic
    gt_set = set(files_in_gt)
    pred_set = set(files_in_pred)

    have_no_gt = [i for i in files_in_pred if i not in gt_set]
    if have_no_gt:
        # explicit raise instead of `assert` so the check is not stripped under `python -O`;
        # AssertionError is kept so existing callers see the same exception type
        raise AssertionError("Some files in folder_predicted have no ground truth in folder_gt: %s"
                             % have_no_gt)

    if any(i not in pred_set for i in files_in_gt):
        print("WARNING! Some files in folder_gt were not predicted (not present in folder_predicted)!")

    files_in_pred.sort()

    full_filenames_gt = [join(folder_gt, i) for i in files_in_pred]
    full_filenames_pred = [join(folder_predicted, i) for i in files_in_pred]

    # the context manager guarantees the worker processes are shut down even if a case fails
    with Pool(processes) as p:
        res = p.starmap(evaluate_case, zip(full_filenames_pred, full_filenames_gt,
                                           [region_labels] * len(files_in_pred)))

    all_results = {r: [] for r in region_names}

    # (row label, reducer over one region's per-case values)
    aggregates = (
        ('mean', np.nanmean),
        ('median', np.nanmedian),
        ('mean (nan is 1)', lambda v: np.mean(_nan_as_one(v))),
        ('median (nan is 1)', lambda v: np.median(_nan_as_one(v))),
    )

    with open(join(folder_predicted, 'summary.csv'), 'w') as f:
        f.write("casename")
        for r in region_names:
            f.write(",%s" % r)
        f.write("\n")

        for case_file, result_here in zip(files_in_pred, res):
            f.write(case_file[:-7])  # strip the '.nii.gz' suffix
            for r, dc in zip(region_names, result_here):
                f.write(",%02.4f" % dc)
                all_results[r].append(dc)
            f.write("\n")

        for row_name, aggregate in aggregates:
            f.write(row_name)
            for r in region_names:
                f.write(",%02.4f" % aggregate(all_results[r]))
            f.write("\n")
|
|
|
|
|
if __name__ == '__main__':
    # gather the cross-validation predictions into one folder, then score them against the reference niftis
    source_dir, merged_dir = './', './cv_niftis'
    collect_cv_niftis(source_dir, merged_dir)
    evaluate_regions('./cv_niftis/', './gt_niftis/', get_brats_regions())
|
|