#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from multiprocessing.pool import Pool
from time import sleep

import matplotlib
from nnunet.configuration import default_num_threads
from nnunet.postprocessing.connected_components import determine_postprocessing
from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
from nnunet.training.dataloading.dataset_loading import DataLoader3D, unpack_dataset
from nnunet.evaluation.evaluator import aggregate_scores
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.paths import network_training_output_dir
from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
from nnunet.utilities.one_hot_encoding import to_one_hot
import shutil

from torch import nn

matplotlib.use("agg")


class nnUNetTrainerV2CascadeFullRes(nnUNetTrainerV2):
    """
    Trainer for the full resolution stage of the cascade.

    In addition to the image data, the network receives the one hot encoded segmentation that was predicted by
    the previous (3d_lowres) stage. Those segmentations are read from
    ``<network_training_output_dir>/3d_lowres/<task>/<previous_trainer>__<plans_identifier>/pred_next_stage``.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainerV2", fp16=False):
        """
        Set up the trainer and derive the folder that holds the segmentations of the previous stage.

        :param previous_trainer: name of the trainer class that was used for the 3d_lowres stage. It is only used
            to build the path of the folder with the previous stage segmentations.

        All other arguments are forwarded unchanged to the parent class constructor.
        """
        super().__init__(plans_file, fold, output_folder, dataset_directory,
                         batch_dice, stage, unpack_data, deterministic, fp16)
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, previous_trainer, fp16)

        if self.output_folder is not None:
            # task and plans identifier are encoded in the output folder path:
            # .../<task>/<trainer>__<plans_identifier>/<last component>
            path_components = self.output_folder.split("/")
            task = path_components[-3]
            plans_identifier = path_components[-2].split("__")[-1]

            folder_with_segs_prev_stage = join(network_training_output_dir, "3d_lowres",
                                               task, previous_trainer + "__" + plans_identifier, "pred_next_stage")
            self.folder_with_segs_from_prev_stage = folder_with_segs_prev_stage
            # Do not put segs_prev_stage into self.output_folder as we need to unpack them for performance and we
            # don't want to do that in self.output_folder because that one is located on some network drive.
        else:
            self.folder_with_segs_from_prev_stage = None

    def do_split(self):
        """
        Run the parent split and attach the path of the previous stage segmentation to every case.

        The key 'seg_from_prev_stage_file' is written into the entries of self.dataset, self.dataset_val and
        self.dataset_tr. For the cases in self.dataset the file must exist, otherwise an AssertionError is raised.
        """
        super().do_split()
        for k in self.dataset:
            self.dataset[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
                                                               k + "_segFromPrevStage.npz")
            assert isfile(self.dataset[k]['seg_from_prev_stage_file']), \
                "seg from prev stage missing: %s. " \
                "Please run all 5 folds of the 3d_lowres configuration of this " \
                "task!" % (self.dataset[k]['seg_from_prev_stage_file'])
        for k in self.dataset_val:
            self.dataset_val[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
                                                                   k + "_segFromPrevStage.npz")
        for k in self.dataset_tr:
            self.dataset_tr[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
                                                                  k + "_segFromPrevStage.npz")

    def get_basic_generators(self):
        """
        Load the dataset, split it and build the training and validation data loaders.

        :return: tuple (dl_tr, dl_val) of DataLoader3D instances
        :raises NotImplementedError: if self.threeD is not set (there is no 2D cascade)
        """
        self.load_dataset()
        self.do_split()

        if not self.threeD:
            raise NotImplementedError("2D has no cascade")

        dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                             True, oversample_foreground_percent=self.oversample_foreground_percent,
                             pad_mode="constant", pad_sides=self.pad_all_sides)
        dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, True,
                              oversample_foreground_percent=self.oversample_foreground_percent,
                              pad_mode="constant", pad_sides=self.pad_all_sides)
        return dl_tr, dl_val

    def process_plans(self, plans):
        """
        Process the plans via the parent class and add the input channels needed for the previous stage
        segmentation (one channel per class in range(1, self.num_classes)).
        """
        super().process_plans(plans)
        self.num_input_channels += (self.num_classes - 1)  # for seg from prev stage

    def setup_DA_params(self):
        """
        Set up the data augmentation parameters of the parent class and add the cascade specific ones.
        """
        super().setup_DA_params()

        self.data_aug_params["num_cached_per_thread"] = 2

        self.data_aug_params['move_last_seg_chanel_to_data'] = True
        self.data_aug_params['cascade_do_cascade_augmentations'] = True

        self.data_aug_params['cascade_random_binary_transform_p'] = 0.4
        self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 1
        self.data_aug_params['cascade_random_binary_transform_size'] = (1, 8)

        self.data_aug_params['cascade_remove_conn_comp_p'] = 0.2
        self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.15
        self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0

        # we have 2 channels now because the segmentation from the previous stage is stored in 'seg' as well until it
        # is moved to 'data' at the end
        self.data_aug_params['selected_seg_channels'] = [0, 1]
        # needed for converting the segmentation from the previous stage to one hot
        self.data_aug_params['all_segmentation_labels'] = list(range(1, self.num_classes))

    def initialize(self, training=True, force_load_plans=False):
        """
        Initialize plans, loss, data generators (training only), network and optimizer.

        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization.

        :param training: if True, the data loaders and augmentation pipelines are created as well
        :param force_load_plans: if True, the plans file is loaded even if self.plans is already set
        :raises RuntimeError: if training is True and the folder with the segmentations of the previous stage
            does not exist
        """
        if not self.was_initialized:
            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # we need to know the number of outputs of the network
            net_numpool = len(self.net_num_pool_op_kernel_sizes)

            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
            # this gives higher resolution outputs more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

            # the last entry gets weight 0 (it is not used for the loss). Normalize the remaining weights so that
            # they sum to 1
            mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.ds_loss_weights = weights
            # now wrap the loss
            self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
            ################# END ###################

            self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                      self.plans['data_identifier'] + ("_stage%d" % self.stage))

            if training:
                if not isdir(self.folder_with_segs_from_prev_stage):
                    raise RuntimeError(
                        "Cannot run final stage of cascade. Run corresponding 3d_lowres first and predict the "
                        "segmentations for the next stage")

                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')

        self.was_initialized = True

    def _resolve_export_params(self, segmentation_export_kwargs):
        """
        Return (force_separate_z, interpolation_order, interpolation_order_z) for the segmentation export.

        Values are taken from segmentation_export_kwargs if given, otherwise from
        self.plans['segmentation_export_params'] if present, otherwise the defaults (None, 1, 0) are used.
        """
        if segmentation_export_kwargs is not None:
            params = segmentation_export_kwargs
        elif 'segmentation_export_params' in self.plans:
            params = self.plans['segmentation_export_params']
        else:
            return None, 1, 0
        return params['force_separate_z'], params['interpolation_order'], params['interpolation_order_z']

    @staticmethod
    def _copy_file_with_retries(source_file, target_folder, max_attempts=10, wait_seconds=1):
        """
        Copy source_file into target_folder, retrying on OSError.

        :raises OSError: the error of the last attempt if all max_attempts attempts failed
        """
        success = False
        attempts = 0
        # The name bound by 'except ... as' is deleted when the except clause ends, so the error must be stored
        # under a different name to be available after the loop.
        last_error = None
        while not success and attempts < max_attempts:
            try:
                shutil.copy(source_file, target_folder)
                success = True
            except OSError as err:
                last_error = err
                attempts += 1
                sleep(wait_seconds)
        if not success:
            print("Could not copy gt nifti file %s into folder %s" % (source_file, target_folder))
            if last_error is not None:
                raise last_error

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
                 save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
        """
        Predict all validation cases, export them as nifti, evaluate them and optionally determine postprocessing.

        The network is put into eval mode with deep supervision disabled for the duration of this method. Its
        previous training mode and deep supervision setting are restored afterwards, also if validation fails.

        :param do_mirroring: use mirroring during prediction. Requires that self.data_aug_params['do_mirror'] is set
        :param use_sliding_window: forwarded to the prediction function
        :param step_size: forwarded to the prediction function
        :param save_softmax: if True, softmax outputs are exported as <case>.npz next to the segmentation
        :param use_gaussian: forwarded to the prediction function
        :param overwrite: if False, cases whose output files already exist are not predicted again
        :param validation_folder_name: name of the subfolder of self.output_folder that receives the results
        :param debug: forwarded to determine_postprocessing
        :param all_in_gpu: forwarded to the prediction function
        :param segmentation_export_kwargs: optional dict with the keys 'force_separate_z', 'interpolation_order'
            and 'interpolation_order_z'. If None, the values from the plans (or defaults) are used
        :param run_postprocessing_on_folds: if True, determine_postprocessing is run after the evaluation
        :raises RuntimeError: if do_mirroring is True but training was done without mirroring
        :raises OSError: if a ground truth nifti could not be copied into the shared gt_niftis folder
        """
        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"

        current_mode = self.network.training
        self.network.eval()
        # save whether network is in deep supervision mode or not
        ds = self.network.do_ds
        # disable deep supervision
        self.network.do_ds = False

        try:
            force_separate_z, interpolation_order, interpolation_order_z = \
                self._resolve_export_params(segmentation_export_kwargs)

            if self.dataset_val is None:
                self.load_dataset()
                self.do_split()

            output_folder = join(self.output_folder, validation_folder_name)
            maybe_mkdir_p(output_folder)
            # this is for debug purposes
            my_input_args = {'do_mirroring': do_mirroring,
                             'use_sliding_window': use_sliding_window,
                             'step': step_size,
                             'save_softmax': save_softmax,
                             'use_gaussian': use_gaussian,
                             'overwrite': overwrite,
                             'validation_folder_name': validation_folder_name,
                             'debug': debug,
                             'all_in_gpu': all_in_gpu,
                             'segmentation_export_kwargs': segmentation_export_kwargs,
                             }
            save_json(my_input_args, join(output_folder, "validation_args.json"))

            if do_mirroring:
                if not self.data_aug_params['do_mirror']:
                    raise RuntimeError(
                        "We did not train with mirroring so you cannot do inference with mirroring enabled")
                mirror_axes = self.data_aug_params['mirror_axes']
            else:
                mirror_axes = ()

            pred_gt_tuples = []
            results = []

            export_pool = Pool(default_num_threads)
            try:
                for k in self.dataset_val:
                    properties = load_pickle(self.dataset[k]['properties_file'])
                    fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
                    nifti_file = join(output_folder, fname + ".nii.gz")

                    if overwrite or (not isfile(nifti_file)) or \
                            (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
                        data = np.load(self.dataset[k]['data_file'])['data']

                        # concat segmentation of previous step
                        seg_from_prev_stage = np.load(join(self.folder_with_segs_from_prev_stage,
                                                           k + "_segFromPrevStage.npz"))['data'][None]

                        print(k, data.shape)
                        data[-1][data[-1] == -1] = 0

                        # the last channel of data is dropped and replaced by the one hot encoded segmentation of
                        # the previous stage
                        data_for_net = np.concatenate((data[:-1], to_one_hot(seg_from_prev_stage[0],
                                                                             range(1, self.num_classes))))

                        softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(
                            data_for_net,
                            do_mirroring=do_mirroring,
                            mirror_axes=mirror_axes,
                            use_sliding_window=use_sliding_window,
                            step_size=step_size,
                            use_gaussian=use_gaussian,
                            all_in_gpu=all_in_gpu,
                            mixed_precision=self.fp16)[1]

                        softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])

                        softmax_fname = join(output_folder, fname + ".npz") if save_softmax else None

                        # There is a problem with python process communication that prevents us from communicating
                        # objects larger than 2 GB between processes (basically when the length of the pickle string
                        # that will be sent is communicated by the multiprocessing.Pipe object then the placeholder
                        # (%i I think) does not allow for long enough strings). This could be fixed by changing i to
                        # l (for long) but that would require manually patching system python code. We circumvent
                        # that problem here by saving softmax_pred to a npy file that will then be read (and finally
                        # deleted) by the Process. save_segmentation_nifti_from_softmax can take either filename or
                        # np.ndarray and will handle this automatically
                        if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be safe
                            np.save(join(output_folder, fname + ".npy"), softmax_pred)
                            softmax_pred = join(output_folder, fname + ".npy")

                        results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
                                                                 ((softmax_pred, nifti_file,
                                                                   properties, interpolation_order, None, None, None,
                                                                   softmax_fname, None, force_separate_z,
                                                                   interpolation_order_z),
                                                                  )
                                                                 )
                                       )

                    pred_gt_tuples.append([nifti_file,
                                           join(self.gt_niftis_folder, fname + ".nii.gz")])

                # wait for all exports to finish (this also re-raises errors from the worker processes)
                _ = [i.get() for i in results]
            finally:
                # always release the worker processes, the pool is not needed beyond this point
                export_pool.close()
                export_pool.join()

            self.print_to_log_file("finished prediction")

            # evaluate raw predictions
            self.print_to_log_file("evaluation of raw predictions")
            task = self.dataset_directory.split("/")[-1]
            job_name = self.experiment_name
            _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
                                 json_output_file=join(output_folder, "summary.json"),
                                 json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
                                 json_author="Fabian",
                                 json_task=task, num_threads=default_num_threads)

            if run_postprocessing_on_folds:
                # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove
                # everything except the largest connected component for each class. To see if this improves results,
                # we do this for all classes and then rerun the evaluation. Those classes for which this resulted in
                # an improved dice score will have this applied during inference as well
                self.print_to_log_file("determining postprocessing")
                determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
                                         final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
                # after this the final predictions for the validation set can be found in
                # validation_folder_name + "_postprocessed". They are always in that folder, even if no
                # postprocessing was applied!

            # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds
            # another postprocessing to be better? In this case we need to consolidate. At the time the consolidation
            # is going to be done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a
            # separate folder to be used later
            gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
            maybe_mkdir_p(gt_nifti_folder)
            for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
                self._copy_file_with_retries(f, gt_nifti_folder)
        finally:
            # restore network training mode and deep supervision mode
            self.network.train(current_mode)
            self.network.do_ds = ds