|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import shutil |
|
from copy import deepcopy |
|
|
|
from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax |
|
from batchgenerators.utilities.file_and_folder_operations import * |
|
import numpy as np |
|
from multiprocessing import Pool |
|
from nnunet.postprocessing.connected_components import apply_postprocessing_to_folder, load_postprocessing |
|
|
|
|
|
def merge_files(files, properties_files, out_file, override, store_npz):
    """Average the softmax predictions of one case and export the merged segmentation.

    Args:
        files: paths to ``.npz`` files, each containing a ``'softmax'`` array.
            The arrays are averaged element-wise, so they must share one shape.
        properties_files: paths to the matching ``.pkl`` property files. Each is
            loaded with ``load_pickle`` and queried for ``'regions_class_order'``,
            so each is expected to be a dict.
        out_file: destination of the merged segmentation. It must end in
            ``.nii.gz``: the optional ``.npz``/``.pkl`` side outputs are named by
            stripping the last 7 characters.
        override: if False and ``out_file`` already exists, nothing is done.
        store_npz: if True, also write the averaged softmax to ``<case>.npz`` and
            the properties to ``<case>.pkl``, next to ``out_file``.

    Raises:
        AssertionError: if the property files disagree on ``regions_class_order``.
    """
    if not override and isfile(out_file):
        return

    # Load each softmax and close the npz handle right away. np.load on an .npz
    # returns a lazily-reading NpzFile that keeps the file open until closed.
    softmaxes = []
    for npz_path in files:
        with np.load(npz_path) as npz:
            softmaxes.append(npz['softmax'])
    softmax = np.mean(np.stack(softmaxes), 0)
    del softmaxes  # release the per-model copies before exporting

    props = [load_pickle(f) for f in properties_files]

    reg_class_orders = [p.get('regions_class_order') for p in props]
    if all(r is None for r in reg_class_orders):
        regions_class_order = None
    else:
        # If any model used regions_class_order, all of them must agree on it.
        regions_class_order = reg_class_orders[0]
        for r in reg_class_orders[1:]:
            assert regions_class_order == r, 'If merging files with regions_class_order, the regions_class_orders of all ' \
                                             'files must be the same. regions_class_order: %s, \n files: %s' % \
                                             (str(reg_class_orders), str(files))

    # The positional arguments 3 / None / None are forwarded unchanged to
    # save_segmentation_nifti_from_softmax (presumably interpolation order and
    # resampling-related options -- confirm against its signature).
    save_segmentation_nifti_from_softmax(softmax, out_file, props[0], 3, regions_class_order, None, None,
                                         force_separate_z=None)
    if store_npz:
        case_base = out_file[:-7]  # strip ".nii.gz"
        np.savez_compressed(case_base + ".npz", softmax=softmax)
        # Store one properties dict (the one used for the export above). This gives
        # the .npz/.pkl pair the same layout as the inputs this function reads.
        # A pickled list could not be merged again (a list has no .get/.keys).
        save_pickle(props[0], case_base + ".pkl")
|
|
|
|
|
def merge(folders, output_folder, threads, override=True, postprocessing_file=None, store_npz=False):
    """Ensemble the ``.npz`` predictions found in ``folders`` into ``output_folder``.

    A case id is every ``<case>.npz`` file name (minus extension) found in any
    folder. Each case must have both ``<case>.npz`` and ``<case>.pkl`` in *all*
    folders. Each case is merged by ``merge_files`` in a process pool of size
    ``threads``.

    If ``postprocessing_file`` is given:
    - the merged niftis go to ``output_folder/not_postprocessed``;
    - ``apply_postprocessing_to_folder`` is run from that folder into
      ``output_folder``, using the configuration from ``load_postprocessing``;
    - the postprocessing file is copied into ``output_folder``.

    Args:
        folders: folders holding per-case ``.npz``/``.pkl`` predictions.
        output_folder: destination folder (created if missing).
        threads: number of worker processes for merging, also forwarded to
            ``apply_postprocessing_to_folder``.
        override: passed to ``merge_files``. If False, existing outputs are kept.
        postprocessing_file: optional postprocessing configuration file.
        store_npz: passed to ``merge_files``. If True, also store the averaged
            softmax and properties per case.

    Raises:
        AssertionError: if some case's ``.npz`` or ``.pkl`` is missing in a folder.
    """
    maybe_mkdir_p(output_folder)

    if postprocessing_file is not None:
        final_output_folder = output_folder
        output_folder = join(output_folder, 'not_postprocessed')
        maybe_mkdir_p(output_folder)
    else:
        final_output_folder = None

    case_ids = np.unique([name[:-4] for folder in folders
                          for name in subfiles(folder, suffix=".npz", join=False)])

    # Every case must be present (npz + pkl) in every folder.
    for folder in folders:
        missing_npz = [c for c in case_ids if not isfile(join(folder, c + ".npz"))]
        assert not missing_npz, "Not all patient npz are available in all folders " \
                                "(missing in %s: %s)" % (folder, missing_npz)
        missing_pkl = [c for c in case_ids if not isfile(join(folder, c + ".pkl"))]
        assert not missing_pkl, "Not all patient pkl are available in all folders " \
                                "(missing in %s: %s)" % (folder, missing_pkl)

    files = [[join(f, c + ".npz") for f in folders] for c in case_ids]
    property_files = [[join(f, c + ".pkl") for f in folders] for c in case_ids]
    out_files = [join(output_folder, c + ".nii.gz") for c in case_ids]

    # The context manager makes sure worker processes are torn down even if a
    # merge fails. starmap blocks until every task has finished, so the
    # terminate() in __exit__ cannot cut off a successful run.
    with Pool(threads) as pool:
        pool.starmap(merge_files, zip(files, property_files, out_files,
                                      [override] * len(out_files), [store_npz] * len(out_files)))

    if postprocessing_file is not None:
        for_which_classes, min_valid_obj_size = load_postprocessing(postprocessing_file)
        print('Postprocessing...')
        apply_postprocessing_to_folder(output_folder, final_output_folder,
                                       for_which_classes, min_valid_obj_size, threads)
        shutil.copy(postprocessing_file, final_output_folder)
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: merge npz predictions from several folders.

    Parses the command line and forwards the options to ``merge``, always with
    ``override=True``.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so a plain ``main()`` call
            behaves as before.
    """
    import argparse

    parser = argparse.ArgumentParser(description="This script will merge predictions (that were predicted with the "
                                                 "-npz option!). You need to specify a postprocessing file so that "
                                                 "we know here what postprocessing must be applied. Failing to do so "
                                                 "will disable postprocessing")
    parser.add_argument('-f', '--folders', nargs='+', help="list of folders to merge. All folders must contain npz "
                                                           "files", required=True)
    parser.add_argument('-o', '--output_folder', help="where to save the results", required=True, type=str)
    parser.add_argument('-t', '--threads', help="number of processes used for saving niftis", required=False,
                        default=2, type=int)
    parser.add_argument('-pp', '--postprocessing_file', help="path to the file where the postprocessing configuration "
                                                             "is stored. If this is not provided then no postprocessing "
                                                             "will be made. It is strongly recommended to provide the "
                                                             "postprocessing file!",
                        required=False, type=str, default=None)
    parser.add_argument('--npz', action="store_true", required=False,
                        help="also store the averaged softmax (npz) and properties (pkl) of each merged case")

    args = parser.parse_args(argv)

    merge(args.folders, args.output_folder, args.threads, override=True,
          postprocessing_file=args.postprocessing_file, store_npz=args.npz)


if __name__ == "__main__":
    main()
|
|