# Source listing metadata (scraped page header): uploaded by ho11laqe, commit ecf08bc ("init"), 6.3 kB.
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from copy import deepcopy
from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
from multiprocessing import Pool
from nnunet.postprocessing.connected_components import apply_postprocessing_to_folder, load_postprocessing
def merge_files(files, properties_files, out_file, override, store_npz):
    """Average the softmax predictions of one case and export the merged segmentation.

    :param files: paths of ``.npz`` files, each holding a ``'softmax'`` array for the same case
    :param properties_files: paths of the matching ``.pkl`` properties files; the first one is used for export
    :param out_file: path of the exported segmentation. Its last 7 characters are stripped to build the
        ``.npz``/``.pkl`` paths when ``store_npz`` is set, so it is expected to end in ``.nii.gz``
    :param override: if False and ``out_file`` already exists, the case is skipped
    :param store_npz: if True, also save the averaged softmax (``.npz``) and the list of loaded properties (``.pkl``)
    :raises AssertionError: if the properties files carry differing ``regions_class_order`` entries
    """
    if not override and isfile(out_file):
        return

    # Each NpzFile is opened in a ``with`` block so its file handle is closed right after the array is read
    # (indexing an NpzFile loads the array into memory, so it remains valid after closing).
    softmax = []
    for f in files:
        with np.load(f) as npz:
            softmax.append(npz['softmax'][None])
    # stack along the new leading axis, then average over it
    softmax = np.vstack(softmax)
    softmax = np.mean(softmax, 0)

    props = [load_pickle(f) for f in properties_files]

    # a missing 'regions_class_order' entry is treated as None
    reg_class_orders = [p['regions_class_order'] if 'regions_class_order' in p else None for p in props]

    if all(r is None for r in reg_class_orders):
        regions_class_order = None
    else:
        # if reg_class_orders are not None then they must be the same in all pkls. Raised explicitly rather than
        # via ``assert`` so the check is not stripped under ``python -O``; the exception type is unchanged.
        regions_class_order = reg_class_orders[0]
        for r in reg_class_orders[1:]:
            if not regions_class_order == r:
                raise AssertionError('If merging files with regions_class_order, the regions_class_orders of all '
                                     'files must be the same. regions_class_order: %s, \n files: %s' %
                                     (str(reg_class_orders), str(files)))

    # Softmax probabilities are already at target spacing so this will not do any resampling (resampling parameters
    # don't matter here)
    save_segmentation_nifti_from_softmax(softmax, out_file, props[0], 3, regions_class_order, None, None,
                                         force_separate_z=None)
    if store_npz:
        np.savez_compressed(out_file[:-7] + ".npz", softmax=softmax)
        save_pickle(props, out_file[:-7] + ".pkl")
def merge(folders, output_folder, threads, override=True, postprocessing_file=None, store_npz=False):
    """Merge (average) the ``.npz`` softmax predictions of several folders, case by case.

    :param folders: folders to ensemble; every case found in any of them must have ``<case>.npz`` and
        ``<case>.pkl`` in all of them
    :param output_folder: destination folder (created if needed)
    :param threads: number of worker processes of the ``multiprocessing.Pool``; also forwarded to
        ``apply_postprocessing_to_folder``
    :param override: forwarded to ``merge_files``; if False, cases whose output already exists are skipped
    :param postprocessing_file: optional postprocessing configuration. If given, the raw merged segmentations are
        written to ``output_folder/not_postprocessed``, the postprocessed ones to ``output_folder``, and the file
        itself is copied into ``output_folder``
    :param store_npz: forwarded to ``merge_files``; also store the averaged softmax and the properties
    :raises AssertionError: if a case is missing its ``.npz`` or ``.pkl`` file in one of the folders
    """
    maybe_mkdir_p(output_folder)

    if postprocessing_file is not None:
        output_folder_orig = output_folder  # strings are immutable, no copy needed
        output_folder = join(output_folder, 'not_postprocessed')
        maybe_mkdir_p(output_folder)
    else:
        output_folder_orig = None

    # union of case identifiers (.npz file names without the extension) over all folders, sorted
    patient_ids = sorted({name[:-4] for folder in folders for name in subfiles(folder, suffix=".npz", join=False)})

    # Every case must be present in every folder. Raised explicitly (not via ``assert``) so the check survives
    # ``python -O``; AssertionError is kept for backward compatibility.
    for folder in folders:
        for ext in (".npz", ".pkl"):
            missing = [i for i in patient_ids if not isfile(join(folder, i + ext))]
            if missing:
                raise AssertionError("Not all patient %s are available in all folders: %s is missing %s"
                                     % (ext[1:], folder, missing))

    files = [[join(f, i + ".npz") for f in folders] for i in patient_ids]
    property_files = [[join(f, i + ".pkl") for f in folders] for i in patient_ids]
    out_files = [join(output_folder, i + ".nii.gz") for i in patient_ids]
    num_cases = len(out_files)

    # The context manager guarantees the worker processes are shut down even if a case fails. starmap only
    # returns (or raises) once all tasks have finished, so the terminate() issued on exit only reaps idle workers.
    with Pool(threads) as pool:
        pool.starmap(merge_files, zip(files, property_files, out_files,
                                      [override] * num_cases, [store_npz] * num_cases))

    if postprocessing_file is not None:
        for_which_classes, min_valid_obj_size = load_postprocessing(postprocessing_file)
        print('Postprocessing...')
        apply_postprocessing_to_folder(output_folder, output_folder_orig,
                                       for_which_classes, min_valid_obj_size, threads)
        shutil.copy(postprocessing_file, output_folder_orig)
def main(argv=None):
    """Command-line entry point: parse the arguments and run :func:`merge`.

    :param argv: argument list to parse instead of ``sys.argv[1:]`` (argparse's default when None)
    """
    import argparse
    parser = argparse.ArgumentParser(description="This script will merge predictions (that were predicted with the "
                                                 "-npz option!). You need to specify a postprocessing file so that "
                                                 "we know here what postprocessing must be applied. Failing to do so "
                                                 "will disable postprocessing.")
    parser.add_argument('-f', '--folders', nargs='+', help="list of folders to merge. All folders must contain npz "
                                                           "files", required=True)
    parser.add_argument('-o', '--output_folder', help="where to save the results", required=True, type=str)
    parser.add_argument('-t', '--threads', help="number of worker processes used for saving niftis", required=False,
                        default=2, type=int)
    parser.add_argument('-pp', '--postprocessing_file', help="path to the file where the postprocessing configuration "
                                                             "is stored. If this is not provided then no postprocessing "
                                                             "will be made. It is strongly recommended to provide the "
                                                             "postprocessing file!",
                        required=False, type=str, default=None)
    parser.add_argument('--npz', action="store_true", required=False, help="stores npz and pkl")

    args = parser.parse_args(argv)

    merge(args.folders, args.output_folder, args.threads, override=True,
          postprocessing_file=args.postprocessing_file, store_npz=args.npz)
if __name__ == "__main__":
    # run the command-line interface only when this file is executed as a script, not when it is imported
    main()