# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse from copy import deepcopy from typing import Tuple, Union, List import numpy as np from batchgenerators.augmentations.utils import resize_segmentation from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax, save_segmentation_nifti from batchgenerators.utilities.file_and_folder_operations import * from multiprocessing import Process, Queue import torch import SimpleITK as sitk import shutil from multiprocessing import Pool from nnunet.postprocessing.connected_components import load_remove_save, load_postprocessing from nnunet.training.model_restore import load_model_and_checkpoint_files from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer from nnunet.utilities.one_hot_encoding import to_one_hot def preprocess_save_to_queue(preprocess_fn, q, list_of_lists, output_files, segs_from_prev_stage, classes, transpose_forward): # suppress output # sys.stdout = open(os.devnull, 'w') errors_in = [] for i, l in enumerate(list_of_lists): try: output_file = output_files[i] print("preprocessing", output_file) d, _, dct = preprocess_fn(l) dct['classes'] = [[0]+ cl for cl in classes] # print(output_file, dct) if segs_from_prev_stage[i] is not None: assert isfile(segs_from_prev_stage[i]) and segs_from_prev_stage[i].endswith( ".nii.gz"), "segs_from_prev_stage" \ " must point to a " \ "segmentation file" seg_prev = sitk.GetArrayFromImage(sitk.ReadImage(segs_from_prev_stage[i])) # check to see if shapes match img = sitk.GetArrayFromImage(sitk.ReadImage(l[0])) assert all([i == j for i, j in zip(seg_prev.shape, img.shape)]), "image and segmentation from previous " \ "stage don't have the same pixel array " \ "shape! image: %s, seg_prev: %s" % \ (l[0], segs_from_prev_stage[i]) seg_prev = seg_prev.transpose(transpose_forward) seg_reshaped = resize_segmentation(seg_prev, d.shape[1:], order=1) seg_reshaped = to_one_hot(seg_reshaped, classes) d = np.vstack((d, seg_reshaped)).astype(np.float32) """There is a problem with python process communication that prevents us from communicating obejcts larger than 2 GB between processes (basically when the length of the pickle string that will be sent is communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either filename or np.ndarray and will handle this automatically""" print(d.shape) if np.prod(d.shape) > (2e9 / 4 * 0.85): # *0.85 just to be save, 4 because float32 is 4 bytes print( "This output is too large for python process-process communication. " "Saving output temporarily to disk") np.save(output_file[:-7] + ".npy", d) d = output_file[:-7] + ".npy" q.put((output_file, (d, dct))) except KeyboardInterrupt: raise KeyboardInterrupt except Exception as e: print("error in", l) print(e) q.put("end") if len(errors_in) > 0: print("There were some errors in the following cases:", errors_in) print("These cases were ignored.") else: print("This worker has ended successfully, no errors to report") # restore output # sys.stdout = sys.__stdout__ def preprocess_multithreaded(trainer, list_of_lists, output_files, num_processes=2, segs_from_prev_stage=None): if segs_from_prev_stage is None: segs_from_prev_stage = [None] * len(list_of_lists) num_processes = min(len(list_of_lists), num_processes) classes = [list(range(1, num_classes)) for num_classes in trainer.num_classes] assert isinstance(trainer, nnUNetTrainer) q = Queue(1) processes = [] for i in range(num_processes): """ pr = preprocess_save_to_queue(trainer.preprocess_patient, q, list_of_lists[i::num_processes], output_files[i::num_processes], segs_from_prev_stage[i::num_processes], classes, trainer.plans['transpose_forward']) """ pr = Process(target=preprocess_save_to_queue, args=(trainer.preprocess_patient, q, list_of_lists[i::num_processes], output_files[i::num_processes], segs_from_prev_stage[i::num_processes], classes, trainer.plans['transpose_forward'])) pr.start() processes.append(pr) try: end_ctr = 0 while end_ctr != num_processes: item = q.get() if item == "end": end_ctr += 1 continue else: yield item finally: for p in processes: if p.is_alive(): p.terminate() # this should not happen but better safe than sorry right p.join() q.close() def predict_cases(model, list_of_lists, output_filenames, folds, save_npz, num_threads_preprocessing, num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True, overwrite_existing=False, all_in_gpu=False, step_size=0.5, checkpoint_name="model_final_checkpoint", segmentation_export_kwargs: dict = None, disable_postprocessing: bool = False): """ :param segmentation_export_kwargs: :param model: folder where the model is saved, must contain fold_x subfolders :param list_of_lists: [[case0_0000.nii.gz, case0_0001.nii.gz], [case1_0000.nii.gz, case1_0001.nii.gz], ...] :param output_filenames: [output_file_case0.nii.gz, output_file_case1.nii.gz, ...] :param folds: default: (0, 1, 2, 3, 4) (but can also be 'all' or a subset of the five folds, for example use (0, ) for using only fold_0 :param save_npz: default: False :param num_threads_preprocessing: :param num_threads_nifti_save: :param segs_from_prev_stage: :param do_tta: default: True, can be set to False for a 8x speedup at the cost of a reduced segmentation quality :param overwrite_existing: default: True :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init :return: """ assert len(list_of_lists) == len(output_filenames) if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames) pool = Pool(num_threads_nifti_save) results = [] cleaned_output_files = [] for o in output_filenames: dr, f = os.path.split(o) if len(dr) > 0: maybe_mkdir_p(dr) if not f.endswith(".nii.gz"): f, _ = os.path.splitext(f) f = f + ".nii.gz" cleaned_output_files.append(join(dr, f)) if not overwrite_existing: print("number of cases:", len(list_of_lists)) # if save_npz=True then we should also check for missing npz files not_done_idx = [i for i, j in enumerate(cleaned_output_files) if (not isfile(j)) or (save_npz and not isfile(j[:-7] + '.npz'))] cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx] list_of_lists = [list_of_lists[i] for i in not_done_idx] if segs_from_prev_stage is not None: segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx] print("number of cases that still need to be predicted:", len(cleaned_output_files)) print("emptying cuda cache") torch.cuda.empty_cache() print("loading parameters for folds,", folds) trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision, checkpoint_name=checkpoint_name) if segmentation_export_kwargs is None: if 'segmentation_export_params' in trainer.plans.keys(): force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z'] interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 else: force_separate_z = segmentation_export_kwargs['force_separate_z'] interpolation_order = segmentation_export_kwargs['interpolation_order'] interpolation_order_z = segmentation_export_kwargs['interpolation_order_z'] print("starting preprocessing generator") preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing, segs_from_prev_stage) print("starting prediction...") all_output_files = [] for preprocessed in preprocessing: output_filename, (d, dct) = preprocessed all_output_files.append(output_filename) if isinstance(d, str): data = np.load(d) os.remove(d) d = data print("predicting", output_filename) trainer.load_checkpoint_ram(params[0], False) softmax = trainer.predict_preprocessed_data_return_seg_and_softmax( d, do_mirroring=do_tta, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True, step_size=step_size, use_gaussian=True, all_in_gpu=all_in_gpu, mixed_precision=mixed_precision)[1] for p in params[1:]: trainer.load_checkpoint_ram(p, False) softmax += trainer.predict_preprocessed_data_return_seg_and_softmax( d, do_mirroring=do_tta, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True, step_size=step_size, use_gaussian=True, all_in_gpu=all_in_gpu, mixed_precision=mixed_precision)[1] if len(params) > 1: softmax /= len(params) transpose_forward = trainer.plans.get('transpose_forward') if transpose_forward is not None: transpose_backward = trainer.plans.get('transpose_backward') softmax = softmax.transpose([0] + [i + 1 for i in transpose_backward]) if save_npz: npz_file = output_filename[:-7] + ".npz" else: npz_file = None if hasattr(trainer, 'regions_class_order'): region_class_order = trainer.regions_class_order else: region_class_order = None """There is a problem with python process communication that prevents us from communicating obejcts larger than 2 GB between processes (basically when the length of the pickle string that will be sent is communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either filename or np.ndarray and will handle this automatically""" bytes_per_voxel = 4 if all_in_gpu: bytes_per_voxel = 2 # if all_in_gpu then the return value is half (float16) if np.prod(softmax.shape) > (2e9 / bytes_per_voxel * 0.85): # * 0.85 just to be save print( "This output is too large for python process-process communication. Saving output temporarily to disk") np.save(output_filename[:-7] + ".npy", softmax) softmax = output_filename[:-7] + ".npy" """ save_segmentation_nifti_from_softmax(softmax, output_filename, dct, interpolation_order, region_class_order, None, None, npz_file, None, force_separate_z, interpolation_order_z) """ results.append(pool.starmap_async(save_segmentation_nifti_from_softmax, ((softmax, output_filename, dct, interpolation_order, region_class_order, None, None, npz_file, None, force_separate_z, interpolation_order_z),) )) print("inference done. Now waiting for the segmentation export to finish...") _ = [i.get() for i in results] # now apply postprocessing # first load the postprocessing properties if they are present. Else raise a well visible warning if not disable_postprocessing: results = [] pp_file = join(model, "postprocessing.json") if isfile(pp_file): print("postprocessing...") shutil.copy(pp_file, os.path.abspath(os.path.dirname(output_filenames[0]))) # for_which_classes stores for which of the classes everything but the largest connected component needs to be # removed for_which_classes, min_valid_obj_size = load_postprocessing(pp_file) results.append(pool.starmap_async(load_remove_save, zip(output_filenames, output_filenames, [for_which_classes] * len(output_filenames), [min_valid_obj_size] * len(output_filenames)))) _ = [i.get() for i in results] else: print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run " "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is " "%s" % model) pool.close() pool.join() def predict_cases_fast(model, list_of_lists, output_filenames, folds, num_threads_preprocessing, num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True, overwrite_existing=False, all_in_gpu=False, step_size=0.5, checkpoint_name="model_final_checkpoint", segmentation_export_kwargs: dict = None, disable_postprocessing: bool = False): assert len(list_of_lists) == len(output_filenames) if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames) pool = Pool(num_threads_nifti_save) results = [] cleaned_output_files = [] for o in output_filenames: dr, f = os.path.split(o) if len(dr) > 0: maybe_mkdir_p(dr) if not f.endswith(".nii.gz"): f, _ = os.path.splitext(f) f = f + ".nii.gz" cleaned_output_files.append(join(dr, f)) if not overwrite_existing: print("number of cases:", len(list_of_lists)) not_done_idx = [i for i, j in enumerate(cleaned_output_files) if not isfile(j)] cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx] list_of_lists = [list_of_lists[i] for i in not_done_idx] if segs_from_prev_stage is not None: segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx] print("number of cases that still need to be predicted:", len(cleaned_output_files)) print("emptying cuda cache") torch.cuda.empty_cache() print("loading parameters for folds,", folds) trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision, checkpoint_name=checkpoint_name) if segmentation_export_kwargs is None: if 'segmentation_export_params' in trainer.plans.keys(): force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z'] interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 else: force_separate_z = segmentation_export_kwargs['force_separate_z'] interpolation_order = segmentation_export_kwargs['interpolation_order'] interpolation_order_z = segmentation_export_kwargs['interpolation_order_z'] print("starting preprocessing generator") preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing, segs_from_prev_stage) print("starting prediction...") for preprocessed in preprocessing: print("getting data from preprocessor") output_filename, (d, dct) = preprocessed print("got something") if isinstance(d, str): print("what I got is a string, so I need to load a file") data = np.load(d) os.remove(d) d = data # preallocate the output arrays # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time) softmax_aggr = None # np.zeros((trainer.num_classes, *d.shape[1:]), dtype=np.float16) all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int) print("predicting", output_filename) for i, p in enumerate(params): trainer.load_checkpoint_ram(p, False) res = trainer.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=do_tta, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True, step_size=step_size, use_gaussian=True, all_in_gpu=all_in_gpu, mixed_precision=mixed_precision) if len(params) > 1: # otherwise we dont need this and we can save ourselves the time it takes to copy that print("aggregating softmax") if softmax_aggr is None: softmax_aggr = res[1] else: softmax_aggr += res[1] all_seg_outputs[i] = res[0] print("obtaining segmentation map") if len(params) > 1: # we dont need to normalize the softmax by 1 / len(params) because this would not change the outcome of the argmax seg = softmax_aggr.argmax(0) else: seg = all_seg_outputs[0] print("applying transpose_backward") transpose_forward = trainer.plans.get('transpose_forward') if transpose_forward is not None: transpose_backward = trainer.plans.get('transpose_backward') seg = seg.transpose([i for i in transpose_backward]) if hasattr(trainer, 'regions_class_order'): region_class_order = trainer.regions_class_order else: region_class_order = None assert region_class_order is None, "predict_cases_fast can only work with regular softmax predictions " \ "and is therefore unable to handle trainer classes with region_class_order" print("initializing segmentation export") results.append(pool.starmap_async(save_segmentation_nifti, ((seg, output_filename, dct, interpolation_order, force_separate_z, interpolation_order_z),) )) print("done") print("inference done. Now waiting for the segmentation export to finish...") _ = [i.get() for i in results] # now apply postprocessing # first load the postprocessing properties if they are present. Else raise a well visible warning if not disable_postprocessing: results = [] pp_file = join(model, "postprocessing.json") if isfile(pp_file): print("postprocessing...") shutil.copy(pp_file, os.path.dirname(output_filenames[0])) # for_which_classes stores for which of the classes everything but the largest connected component needs to be # removed for_which_classes, min_valid_obj_size = load_postprocessing(pp_file) results.append(pool.starmap_async(load_remove_save, zip(output_filenames, output_filenames, [for_which_classes] * len(output_filenames), [min_valid_obj_size] * len(output_filenames)))) _ = [i.get() for i in results] else: print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run " "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is " "%s" % model) pool.close() pool.join() def predict_cases_fastest(model, list_of_lists, output_filenames, folds, num_threads_preprocessing, num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True, overwrite_existing=False, all_in_gpu=False, step_size=0.5, checkpoint_name="model_final_checkpoint", disable_postprocessing: bool = False): assert len(list_of_lists) == len(output_filenames) if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames) pool = Pool(num_threads_nifti_save) results = [] cleaned_output_files = [] for o in output_filenames: dr, f = os.path.split(o) if len(dr) > 0: maybe_mkdir_p(dr) if not f.endswith(".nii.gz"): f, _ = os.path.splitext(f) f = f + ".nii.gz" cleaned_output_files.append(join(dr, f)) if not overwrite_existing: print("number of cases:", len(list_of_lists)) not_done_idx = [i for i, j in enumerate(cleaned_output_files) if not isfile(j)] cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx] list_of_lists = [list_of_lists[i] for i in not_done_idx] if segs_from_prev_stage is not None: segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx] print("number of cases that still need to be predicted:", len(cleaned_output_files)) print("emptying cuda cache") torch.cuda.empty_cache() print("loading parameters for folds,", folds) trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision, checkpoint_name=checkpoint_name) print("starting preprocessing generator") preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing, segs_from_prev_stage) print("starting prediction...") for preprocessed in preprocessing: print("getting data from preprocessor") output_filename, (d, dct) = preprocessed print("got something") if isinstance(d, str): print("what I got is a string, so I need to load a file") data = np.load(d) os.remove(d) d = data # preallocate the output arrays # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time) all_softmax_outputs = np.zeros((len(params), trainer.num_classes, *d.shape[1:]), dtype=np.float16) all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int) print("predicting", output_filename) for i, p in enumerate(params): trainer.load_checkpoint_ram(p, False) res = trainer.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=do_tta, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True, step_size=step_size, use_gaussian=True, all_in_gpu=all_in_gpu, mixed_precision=mixed_precision) if len(params) > 1: # otherwise we dont need this and we can save ourselves the time it takes to copy that all_softmax_outputs[i] = res[1] all_seg_outputs[i] = res[0] if hasattr(trainer, 'regions_class_order'): region_class_order = trainer.regions_class_order else: region_class_order = None assert region_class_order is None, "predict_cases_fastest can only work with regular softmax predictions " \ "and is therefore unable to handle trainer classes with region_class_order" print("aggregating predictions") if len(params) > 1: softmax_mean = np.mean(all_softmax_outputs, 0) seg = softmax_mean.argmax(0) else: seg = all_seg_outputs[0] print("applying transpose_backward") transpose_forward = trainer.plans.get('transpose_forward') if transpose_forward is not None: transpose_backward = trainer.plans.get('transpose_backward') seg = seg.transpose([i for i in transpose_backward]) print("initializing segmentation export") results.append(pool.starmap_async(save_segmentation_nifti, ((seg, output_filename, dct, 0, None),) )) print("done") print("inference done. Now waiting for the segmentation export to finish...") _ = [i.get() for i in results] # now apply postprocessing # first load the postprocessing properties if they are present. Else raise a well visible warning if not disable_postprocessing: results = [] pp_file = join(model, "postprocessing.json") if isfile(pp_file): print("postprocessing...") shutil.copy(pp_file, os.path.dirname(output_filenames[0])) # for_which_classes stores for which of the classes everything but the largest connected component needs to be # removed for_which_classes, min_valid_obj_size = load_postprocessing(pp_file) results.append(pool.starmap_async(load_remove_save, zip(output_filenames, output_filenames, [for_which_classes] * len(output_filenames), [min_valid_obj_size] * len(output_filenames)))) _ = [i.get() for i in results] else: print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run " "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is " "%s" % model) pool.close() pool.join() def check_input_folder_and_return_caseIDs(input_folder, expected_num_modalities): print("This model expects %d input modalities for each image" % expected_num_modalities) files = subfiles(input_folder, suffix=".nii.gz", join=False, sort=True) maybe_case_ids = np.unique([i[:-12] for i in files]) remaining = deepcopy(files) missing = [] assert len(files) > 0, "input folder did not contain any images (expected to find .nii.gz file endings)" # now check if all required files are present and that no unexpected files are remaining for c in maybe_case_ids: for n in range(expected_num_modalities): expected_output_file = c + "_%04.0d.nii.gz" % n if not isfile(join(input_folder, expected_output_file)): missing.append(expected_output_file) else: remaining.remove(expected_output_file) print("Found %d unique case ids, here are some examples:" % len(maybe_case_ids), np.random.choice(maybe_case_ids, min(len(maybe_case_ids), 10))) print("If they don't look right, make sure to double check your filenames. They must end with _0000.nii.gz etc") if len(remaining) > 0: print("found %d unexpected remaining files in the folder. Here are some examples:" % len(remaining), np.random.choice(remaining, min(len(remaining), 10))) if len(missing) > 0: print("Some files are missing:") print(missing) raise RuntimeError("missing files in input_folder") return maybe_case_ids def predict_from_folder(model: str, input_folder: str, output_folder: str, folds: Union[Tuple[int], List[int]], save_npz: bool, num_threads_preprocessing: int, num_threads_nifti_save: int, lowres_segmentations: Union[str, None], part_id: int, num_parts: int, tta: bool, mixed_precision: bool = True, overwrite_existing: bool = True, mode: str = 'normal', overwrite_all_in_gpu: bool = None, step_size: float = 0.5, checkpoint_name: str = "model_final_checkpoint", segmentation_export_kwargs: dict = None, disable_postprocessing: bool = False): """ here we use the standard naming scheme to generate list_of_lists and output_files needed by predict_cases :param model: :param input_folder: :param output_folder: :param folds: :param save_npz: :param num_threads_preprocessing: :param num_threads_nifti_save: :param lowres_segmentations: :param part_id: :param num_parts: :param tta: :param mixed_precision: :param overwrite_existing: if not None then it will be overwritten with whatever is in there. None is default (no overwrite) :return: """ maybe_mkdir_p(output_folder) shutil.copy(join(model, 'plans.pkl'), output_folder) assert isfile(join(model, "plans.pkl")), "Folder with saved model weights must contain a plans.pkl file" expected_num_modalities = load_pickle(join(model, "plans.pkl"))['num_modalities'] # check input folder integrity case_ids = check_input_folder_and_return_caseIDs(input_folder, expected_num_modalities) output_files = [join(output_folder, i + ".nii.gz") for i in case_ids] all_files = subfiles(input_folder, suffix=".nii.gz", join=False, sort=True) list_of_lists = [[join(input_folder, i) for i in all_files if i[:len(j)].startswith(j) and len(i) == (len(j) + 12)] for j in case_ids] if lowres_segmentations is not None: assert isdir(lowres_segmentations), "if lowres_segmentations is not None then it must point to a directory" lowres_segmentations = [join(lowres_segmentations, i + ".nii.gz") for i in case_ids] assert all([isfile(i) for i in lowres_segmentations]), "not all lowres_segmentations files are present. " \ "(I was searching for case_id.nii.gz in that folder)" lowres_segmentations = lowres_segmentations[part_id::num_parts] else: lowres_segmentations = None if mode == "normal": if overwrite_all_in_gpu is None: all_in_gpu = False else: all_in_gpu = overwrite_all_in_gpu return predict_cases(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds, save_npz, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=checkpoint_name, segmentation_export_kwargs=segmentation_export_kwargs, disable_postprocessing=disable_postprocessing) elif mode == "fast": if overwrite_all_in_gpu is None: all_in_gpu = False else: all_in_gpu = overwrite_all_in_gpu assert save_npz is False return predict_cases_fast(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=checkpoint_name, segmentation_export_kwargs=segmentation_export_kwargs, disable_postprocessing=disable_postprocessing) elif mode == "fastest": if overwrite_all_in_gpu is None: all_in_gpu = False else: all_in_gpu = overwrite_all_in_gpu assert save_npz is False return predict_cases_fastest(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=checkpoint_name, disable_postprocessing=disable_postprocessing) else: raise ValueError("unrecognized mode. Must be normal, fast or fastest") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-i", '--input_folder', help="Must contain all modalities for each patient in the correct" " order (same as training). Files must be named " "CASENAME_XXXX.nii.gz where XXXX is the modality " "identifier (0000, 0001, etc)", required=True) parser.add_argument('-o', "--output_folder", required=True, help="folder for saving predictions") parser.add_argument('-m', '--model_output_folder', help='model output folder. Will automatically discover the folds ' 'that were ' 'run and use those as an ensemble', required=True) parser.add_argument('-f', '--folds', nargs='+', default='None', help="folds to use for prediction. Default is None " "which means that folds will be detected " "automatically in the model output folder") parser.add_argument('-z', '--save_npz', required=False, action='store_true', help="use this if you want to ensemble" " these predictions with those of" " other models. Softmax " "probabilities will be saved as " "compresed numpy arrays in " "output_folder and can be merged " "between output_folders with " "merge_predictions.py") parser.add_argument('-l', '--lowres_segmentations', required=False, default='None', help="if model is the highres " "stage of the cascade then you need to use -l to specify where the segmentations of the " "corresponding lowres unet are. Here they are required to do a prediction") parser.add_argument("--part_id", type=int, required=False, default=0, help="Used to parallelize the prediction of " "the folder over several GPUs. If you " "want to use n GPUs to predict this " "folder you need to run this command " "n times with --part_id=0, ... n-1 and " "--num_parts=n (each with a different " "GPU (for example via " "CUDA_VISIBLE_DEVICES=X)") parser.add_argument("--num_parts", type=int, required=False, default=1, help="Used to parallelize the prediction of " "the folder over several GPUs. If you " "want to use n GPUs to predict this " "folder you need to run this command " "n times with --part_id=0, ... n-1 and " "--num_parts=n (each with a different " "GPU (via " "CUDA_VISIBLE_DEVICES=X)") parser.add_argument("--num_threads_preprocessing", required=False, default=6, type=int, help= "Determines many background processes will be used for data preprocessing. Reduce this if you " "run into out of memory (RAM) problems. Default: 6") parser.add_argument("--num_threads_nifti_save", required=False, default=2, type=int, help= "Determines many background processes will be used for segmentation export. Reduce this if you " "run into out of memory (RAM) problems. Default: 2") parser.add_argument("--tta", required=False, type=int, default=1, help="Set to 0 to disable test time data " "augmentation (speedup of factor " "4(2D)/8(3D)), " "lower quality segmentations") parser.add_argument("--overwrite_existing", required=False, type=int, default=1, help="Set this to 0 if you need " "to resume a previous " "prediction. Default: 1 " "(=existing segmentations " "in output_folder will be " "overwritten)") parser.add_argument("--mode", type=str, default="normal", required=False) parser.add_argument("--all_in_gpu", type=str, default="None", required=False, help="can be None, False or True") parser.add_argument("--step_size", type=float, default=0.5, required=False, help="don't touch") # parser.add_argument("--interp_order", required=False, default=3, type=int, # help="order of interpolation for segmentations, has no effect if mode=fastest") # parser.add_argument("--interp_order_z", required=False, default=0, type=int, # help="order of interpolation along z is z is done differently") # parser.add_argument("--force_separate_z", required=False, default="None", type=str, # help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest") parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False, help='Predictions are done with mixed precision by default. This improves speed and reduces ' 'the required vram. If you want to disable mixed precision you can set this flag. Note ' 'that yhis is not recommended (mixed precision is ~2x faster!)') args = parser.parse_args() input_folder = args.input_folder output_folder = args.output_folder part_id = args.part_id num_parts = args.num_parts model = args.model_output_folder folds = args.folds save_npz = args.save_npz lowres_segmentations = args.lowres_segmentations num_threads_preprocessing = args.num_threads_preprocessing num_threads_nifti_save = args.num_threads_nifti_save tta = args.tta step_size = args.step_size # interp_order = args.interp_order # interp_order_z = args.interp_order_z # force_separate_z = args.force_separate_z # if force_separate_z == "None": # force_separate_z = None # elif force_separate_z == "False": # force_separate_z = False # elif force_separate_z == "True": # force_separate_z = True # else: # raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z) overwrite = args.overwrite_existing mode = args.mode all_in_gpu = args.all_in_gpu if lowres_segmentations == "None": lowres_segmentations = None if isinstance(folds, list): if folds[0] == 'all' and len(folds) == 1: pass else: folds = [int(i) for i in folds] elif folds == "None": folds = None else: raise ValueError("Unexpected value for argument folds") if tta == 0: tta = False elif tta == 1: tta = True else: raise ValueError("Unexpected value for tta, Use 1 or 0") if overwrite == 0: overwrite = False elif overwrite == 1: overwrite = True else: raise ValueError("Unexpected value for overwrite, Use 1 or 0") assert all_in_gpu in ['None', 'False', 'True'] if all_in_gpu == "None": all_in_gpu = None elif all_in_gpu == "True": all_in_gpu = True elif all_in_gpu == "False": all_in_gpu = False predict_from_folder(model, input_folder, output_folder, folds, save_npz, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, part_id, num_parts, tta, mixed_precision=not args.disable_mixed_precision, overwrite_existing=overwrite, mode=mode, overwrite_all_in_gpu=all_in_gpu, step_size=step_size)