# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ast from copy import deepcopy from multiprocessing.pool import Pool import numpy as np from nnunet.configuration import default_num_threads from nnunet.evaluation.evaluator import aggregate_scores from scipy.ndimage import label import SimpleITK as sitk from nnunet.utilities.sitk_stuff import copy_geometry from batchgenerators.utilities.file_and_folder_operations import * import shutil def load_remove_save(input_file: str, output_file: str, for_which_classes: list, minimum_valid_object_size: dict = None): # Only objects larger than minimum_valid_object_size will be removed. Keys in minimum_valid_object_size must # match entries in for_which_classes img_in = sitk.ReadImage(input_file) img_npy = sitk.GetArrayFromImage(img_in) volume_per_voxel = float(np.prod(img_in.GetSpacing(), dtype=np.float64)) image, largest_removed, kept_size = remove_all_but_the_largest_connected_component(img_npy, for_which_classes, volume_per_voxel, minimum_valid_object_size) # print(input_file, "kept:", kept_size) img_out_itk = sitk.GetImageFromArray(image) img_out_itk = copy_geometry(img_out_itk, img_in) sitk.WriteImage(img_out_itk, output_file) return largest_removed, kept_size def remove_all_but_the_largest_connected_component(image: np.ndarray, for_which_classes: list, volume_per_voxel: float, minimum_valid_object_size: dict = None): """ removes all but the largest connected component, individually for each class :param image: :param for_which_classes: can be None. Should be list of int. Can also be something like [(1, 2), 2, 4]. Here (1, 2) will be treated as a joint region, not individual classes (example LiTS here we can use (1, 2) to use all foreground classes together) :param minimum_valid_object_size: Only objects larger than minimum_valid_object_size will be removed. Keys in minimum_valid_object_size must match entries in for_which_classes :return: """ if for_which_classes is None: for_which_classes = np.unique(image) for_which_classes = for_which_classes[for_which_classes > 0] assert 0 not in for_which_classes, "cannot remove background" largest_removed = {} kept_size = {} for c in for_which_classes: if isinstance(c, (list, tuple)): c = tuple(c) # otherwise it cant be used as key in the dict mask = np.zeros_like(image, dtype=bool) for cl in c: mask[image == cl] = True else: mask = image == c # get labelmap and number of objects lmap, num_objects = label(mask.astype(int)) # collect object sizes object_sizes = {} for object_id in range(1, num_objects + 1): object_sizes[object_id] = (lmap == object_id).sum() * volume_per_voxel largest_removed[c] = None kept_size[c] = None if num_objects > 0: # we always keep the largest object. We could also consider removing the largest object if it is smaller # than minimum_valid_object_size in the future but we don't do that now. maximum_size = max(object_sizes.values()) kept_size[c] = maximum_size for object_id in range(1, num_objects + 1): # we only remove objects that are not the largest if object_sizes[object_id] != maximum_size: # we only remove objects that are smaller than minimum_valid_object_size remove = True if minimum_valid_object_size is not None: remove = object_sizes[object_id] < minimum_valid_object_size[c] if remove: image[(lmap == object_id) & mask] = 0 if largest_removed[c] is None: largest_removed[c] = object_sizes[object_id] else: largest_removed[c] = max(largest_removed[c], object_sizes[object_id]) return image, largest_removed, kept_size def load_postprocessing(json_file): ''' loads the relevant part of the pkl file that is needed for applying postprocessing :param pkl_file: :return: ''' a = load_json(json_file) if 'min_valid_object_sizes' in a.keys(): min_valid_object_sizes = ast.literal_eval(a['min_valid_object_sizes']) else: min_valid_object_sizes = None return a['for_which_classes'], min_valid_object_sizes def determine_postprocessing(base, gt_labels_folder, raw_subfolder_name="validation_raw", temp_folder="temp", final_subf_name="validation_final", processes=default_num_threads, dice_threshold=0, debug=False, advanced_postprocessing=False, pp_filename="postprocessing.json"): """ :param base: :param gt_labels_folder: subfolder of base with niftis of ground truth labels :param raw_subfolder_name: subfolder of base with niftis of predicted (non-postprocessed) segmentations :param temp_folder: used to store temporary data, will be deleted after we are done here undless debug=True :param final_subf_name: final results will be stored here (subfolder of base) :param processes: :param dice_threshold: only apply postprocessing if results is better than old_result+dice_threshold (can be used as eps) :param debug: if True then the temporary files will not be deleted :return: """ # lets see what classes are in the dataset classes = [int(i) for i in load_json(join(base, raw_subfolder_name, "summary.json"))['results']['mean'].keys() if int(i) != 0] folder_all_classes_as_fg = join(base, temp_folder + "_allClasses") folder_per_class = join(base, temp_folder + "_perClass") if isdir(folder_all_classes_as_fg): shutil.rmtree(folder_all_classes_as_fg) if isdir(folder_per_class): shutil.rmtree(folder_per_class) # multiprocessing rules p = Pool(processes) assert isfile(join(base, raw_subfolder_name, "summary.json")), "join(base, raw_subfolder_name) does not " \ "contain a summary.json" # these are all the files we will be dealing with fnames = subfiles(join(base, raw_subfolder_name), suffix=".nii.gz", join=False) # make output and temp dir maybe_mkdir_p(folder_all_classes_as_fg) maybe_mkdir_p(folder_per_class) maybe_mkdir_p(join(base, final_subf_name)) pp_results = {} pp_results['dc_per_class_raw'] = {} pp_results['dc_per_class_pp_all'] = {} # dice scores after treating all foreground classes as one pp_results['dc_per_class_pp_per_class'] = {} # dice scores after removing everything except larges cc # independently for each class after we already did dc_per_class_pp_all pp_results['for_which_classes'] = [] pp_results['min_valid_object_sizes'] = {} validation_result_raw = load_json(join(base, raw_subfolder_name, "summary.json"))['results'] pp_results['num_samples'] = len(validation_result_raw['all']) validation_result_raw = validation_result_raw['mean'] if advanced_postprocessing: # first treat all foreground classes as one and remove all but the largest foreground connected component results = [] for f in fnames: predicted_segmentation = join(base, raw_subfolder_name, f) # now remove all but the largest connected component for each class output_file = join(folder_all_classes_as_fg, f) results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, (classes,)),))) results = [i.get() for i in results] # aggregate max_size_removed and min_size_kept max_size_removed = {} min_size_kept = {} for tmp in results: mx_rem, min_kept = tmp[0] for k in mx_rem: if mx_rem[k] is not None: if max_size_removed.get(k) is None: max_size_removed[k] = mx_rem[k] else: max_size_removed[k] = max(max_size_removed[k], mx_rem[k]) for k in min_kept: if min_kept[k] is not None: if min_size_kept.get(k) is None: min_size_kept[k] = min_kept[k] else: min_size_kept[k] = min(min_size_kept[k], min_kept[k]) print("foreground vs background, smallest valid object size was", min_size_kept[tuple(classes)]) print("removing only objects smaller than that...") else: min_size_kept = None # we need to rerun the step from above, now with the size constraint pred_gt_tuples = [] results = [] # first treat all foreground classes as one and remove all but the largest foreground connected component for f in fnames: predicted_segmentation = join(base, raw_subfolder_name, f) # now remove all but the largest connected component for each class output_file = join(folder_all_classes_as_fg, f) results.append( p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, (classes,), min_size_kept),))) pred_gt_tuples.append([output_file, join(gt_labels_folder, f)]) _ = [i.get() for i in results] # evaluate postprocessed predictions _ = aggregate_scores(pred_gt_tuples, labels=classes, json_output_file=join(folder_all_classes_as_fg, "summary.json"), json_author="Fabian", num_threads=processes) # now we need to figure out if doing this improved the dice scores. We will implement that defensively in so far # that if a single class got worse as a result we won't do this. We can change this in the future but right now I # prefer to do it this way validation_result_PP_test = load_json(join(folder_all_classes_as_fg, "summary.json"))['results']['mean'] for c in classes: dc_raw = validation_result_raw[str(c)]['Dice'] dc_pp = validation_result_PP_test[str(c)]['Dice'] pp_results['dc_per_class_raw'][str(c)] = dc_raw pp_results['dc_per_class_pp_all'][str(c)] = dc_pp # true if new is better do_fg_cc = False comp = [pp_results['dc_per_class_pp_all'][str(cl)] > (pp_results['dc_per_class_raw'][str(cl)] + dice_threshold) for cl in classes] before = np.mean([pp_results['dc_per_class_raw'][str(cl)] for cl in classes]) after = np.mean([pp_results['dc_per_class_pp_all'][str(cl)] for cl in classes]) print("Foreground vs background") print("before:", before) print("after: ", after) if any(comp): # at least one class improved - yay! # now check if another got worse # true if new is worse any_worse = any( [pp_results['dc_per_class_pp_all'][str(cl)] < pp_results['dc_per_class_raw'][str(cl)] for cl in classes]) if not any_worse: pp_results['for_which_classes'].append(classes) if min_size_kept is not None: pp_results['min_valid_object_sizes'].update(deepcopy(min_size_kept)) do_fg_cc = True print("Removing all but the largest foreground region improved results!") print('for_which_classes', classes) print('min_valid_object_sizes', min_size_kept) else: # did not improve things - don't do it pass if len(classes) > 1: # now depending on whether we do remove all but the largest foreground connected component we define the source dir # for the next one to be the raw or the temp dir if do_fg_cc: source = folder_all_classes_as_fg else: source = join(base, raw_subfolder_name) if advanced_postprocessing: # now run this for each class separately results = [] for f in fnames: predicted_segmentation = join(source, f) output_file = join(folder_per_class, f) results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, classes),))) results = [i.get() for i in results] # aggregate max_size_removed and min_size_kept max_size_removed = {} min_size_kept = {} for tmp in results: mx_rem, min_kept = tmp[0] for k in mx_rem: if mx_rem[k] is not None: if max_size_removed.get(k) is None: max_size_removed[k] = mx_rem[k] else: max_size_removed[k] = max(max_size_removed[k], mx_rem[k]) for k in min_kept: if min_kept[k] is not None: if min_size_kept.get(k) is None: min_size_kept[k] = min_kept[k] else: min_size_kept[k] = min(min_size_kept[k], min_kept[k]) print("classes treated separately, smallest valid object sizes are") print(min_size_kept) print("removing only objects smaller than that") else: min_size_kept = None # rerun with the size thresholds from above pred_gt_tuples = [] results = [] for f in fnames: predicted_segmentation = join(source, f) output_file = join(folder_per_class, f) results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, classes, min_size_kept),))) pred_gt_tuples.append([output_file, join(gt_labels_folder, f)]) _ = [i.get() for i in results] # evaluate postprocessed predictions _ = aggregate_scores(pred_gt_tuples, labels=classes, json_output_file=join(folder_per_class, "summary.json"), json_author="Fabian", num_threads=processes) if do_fg_cc: old_res = deepcopy(validation_result_PP_test) else: old_res = validation_result_raw # these are the new dice scores validation_result_PP_test = load_json(join(folder_per_class, "summary.json"))['results']['mean'] for c in classes: dc_raw = old_res[str(c)]['Dice'] dc_pp = validation_result_PP_test[str(c)]['Dice'] pp_results['dc_per_class_pp_per_class'][str(c)] = dc_pp print(c) print("before:", dc_raw) print("after: ", dc_pp) if dc_pp > (dc_raw + dice_threshold): pp_results['for_which_classes'].append(int(c)) if min_size_kept is not None: pp_results['min_valid_object_sizes'].update({c: min_size_kept[c]}) print("Removing all but the largest region for class %d improved results!" % c) print('min_valid_object_sizes', min_size_kept) else: print("Only one class present, no need to do each class separately as this is covered in fg vs bg") if not advanced_postprocessing: pp_results['min_valid_object_sizes'] = None print("done") print("for which classes:") print(pp_results['for_which_classes']) print("min_object_sizes") print(pp_results['min_valid_object_sizes']) pp_results['validation_raw'] = raw_subfolder_name pp_results['validation_final'] = final_subf_name # now that we have a proper for_which_classes, apply that pred_gt_tuples = [] results = [] for f in fnames: predicted_segmentation = join(base, raw_subfolder_name, f) # now remove all but the largest connected component for each class output_file = join(base, final_subf_name, f) results.append(p.starmap_async(load_remove_save, ( (predicted_segmentation, output_file, pp_results['for_which_classes'], pp_results['min_valid_object_sizes']),))) pred_gt_tuples.append([output_file, join(gt_labels_folder, f)]) _ = [i.get() for i in results] # evaluate postprocessed predictions _ = aggregate_scores(pred_gt_tuples, labels=classes, json_output_file=join(base, final_subf_name, "summary.json"), json_author="Fabian", num_threads=processes) pp_results['min_valid_object_sizes'] = str(pp_results['min_valid_object_sizes']) save_json(pp_results, join(base, pp_filename)) # delete temp if not debug: shutil.rmtree(folder_per_class) shutil.rmtree(folder_all_classes_as_fg) p.close() p.join() print("done") def apply_postprocessing_to_folder(input_folder: str, output_folder: str, for_which_classes: list, min_valid_object_size:dict=None, num_processes=8): """ applies removing of all but the largest connected component to all niftis in a folder :param min_valid_object_size: :param min_valid_object_size: :param input_folder: :param output_folder: :param for_which_classes: :param num_processes: :return: """ maybe_mkdir_p(output_folder) p = Pool(num_processes) nii_files = subfiles(input_folder, suffix=".nii.gz", join=False) input_files = [join(input_folder, i) for i in nii_files] out_files = [join(output_folder, i) for i in nii_files] results = p.starmap_async(load_remove_save, zip(input_files, out_files, [for_which_classes] * len(input_files), [min_valid_object_size] * len(input_files))) res = results.get() p.close() p.join() if __name__ == "__main__": input_folder = "/media/fabian/DKFZ/predictions_Fabian/Liver_and_LiverTumor" output_folder = "/media/fabian/DKFZ/predictions_Fabian/Liver_and_LiverTumor_postprocessed" for_which_classes = [(1, 2), ] apply_postprocessing_to_folder(input_folder, output_folder, for_which_classes)