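# NOTE: the following text is web file-viewer residue captured together with this file and is not part of
# the module: "ho11laqe's picture", "init", "ecf08bc", "raw", "history blame", "No virus", "16.1 kB".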
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from multiprocessing.pool import Pool
from time import sleep
import matplotlib
from nnunet.postprocessing.connected_components import determine_postprocessing
from nnunet.training.data_augmentation.default_data_augmentation import get_default_augmentation
from nnunet.training.dataloading.dataset_loading import DataLoader3D, unpack_dataset
from nnunet.evaluation.evaluator import aggregate_scores
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.paths import network_training_output_dir
from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
from nnunet.utilities.one_hot_encoding import to_one_hot
import shutil
# Select matplotlib's "agg" backend at import time, before any plotting code in the training pipeline runs.
# NOTE(review): Agg is matplotlib's non-interactive raster backend; presumably chosen so training works on
# machines without a display - confirm this is still required here.
matplotlib.use("agg")
class nnUNetTrainerCascadeFullRes(nnUNetTrainer):
    """Trainer for the full-resolution stage of the 3D cascade.

    On top of the regular image channels the network is fed the one-hot encoded segmentation that
    was predicted by the preceding ``3d_lowres`` stage. These predictions are expected in
    ``<network_training_output_dir>/3d_lowres/<task>/<previous_trainer>__<plans_identifier>/pred_next_stage``
    as ``<case>_segFromPrevStage.npz`` files.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainer", fp16=False):
        """
        All arguments except ``previous_trainer`` are forwarded unchanged to ``nnUNetTrainer.__init__``.

        :param previous_trainer: name of the trainer class that was used for the ``3d_lowres`` stage; it is used
            to locate the folder with the low resolution predictions.
        :raises RuntimeError: if ``self.output_folder`` is set but the folder with the predictions of the previous
            stage does not exist.
        """
        super().__init__(plans_file, fold, output_folder, dataset_directory,
                         batch_dice, stage, unpack_data, deterministic, fp16)
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, previous_trainer, fp16)

        if self.output_folder is not None:
            # task name and plans identifier are recovered from the path components of the output folder
            path_components = self.output_folder.split("/")
            task = path_components[-3]
            plans_identifier = path_components[-2].split("__")[-1]

            folder_with_segs_prev_stage = join(network_training_output_dir, "3d_lowres",
                                               task, previous_trainer + "__" + plans_identifier, "pred_next_stage")
            if not isdir(folder_with_segs_prev_stage):
                raise RuntimeError(
                    "Cannot run final stage of cascade. Run corresponding 3d_lowres first and predict the "
                    "segmentations for the next stage")
            self.folder_with_segs_from_prev_stage = folder_with_segs_prev_stage
            # Do not put segs_prev_stage into self.output_folder as we need to unpack them for performance and we
            # don't want to do that in self.output_folder because that one is located on some network drive.
        else:
            self.folder_with_segs_from_prev_stage = None

    def _prev_stage_seg_file(self, case_identifier):
        """Return the path of the previous-stage segmentation (``.npz``) belonging to ``case_identifier``."""
        return join(self.folder_with_segs_from_prev_stage, case_identifier + "_segFromPrevStage.npz")

    def do_split(self):
        """
        Run the split of the parent class and attach the path of the previous-stage segmentation to every case in
        ``self.dataset``, ``self.dataset_val`` and ``self.dataset_tr`` (key ``'seg_from_prev_stage_file'``).

        The existence of the file is asserted for every case in ``self.dataset``.
        """
        super().do_split()

        for k in self.dataset:
            seg_file = self._prev_stage_seg_file(k)
            self.dataset[k]['seg_from_prev_stage_file'] = seg_file
            assert isfile(seg_file), "seg from prev stage missing: %s" % seg_file

        for k in self.dataset_val:
            self.dataset_val[k]['seg_from_prev_stage_file'] = self._prev_stage_seg_file(k)

        for k in self.dataset_tr:
            self.dataset_tr[k]['seg_from_prev_stage_file'] = self._prev_stage_seg_file(k)

    def get_basic_generators(self):
        """
        Load the dataset, split it and build the training and validation data loaders.

        :return: tuple ``(dl_tr, dl_val)`` of ``DataLoader3D`` instances
        :raises NotImplementedError: if the trainer is not configured for 3D (``self.threeD`` is falsy)
        """
        self.load_dataset()
        self.do_split()

        if not self.threeD:
            # the cascade only exists for 3D configurations
            raise NotImplementedError

        dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                             True, oversample_foreground_percent=self.oversample_foreground_percent)
        dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, True,
                              oversample_foreground_percent=self.oversample_foreground_percent)
        return dl_tr, dl_val

    def process_plans(self, plans):
        """Process the plans like the parent class and add the input channels needed for the previous-stage seg."""
        super().process_plans(plans)
        # one additional input channel per foreground class (one-hot encoding of the seg from the prev stage)
        self.num_input_channels += (self.num_classes - 1)

    def setup_DA_params(self):
        """Extend the data augmentation parameters of the parent class with the cascade specific settings."""
        super().setup_DA_params()
        self.data_aug_params['move_last_seg_chanel_to_data'] = True
        self.data_aug_params['cascade_do_cascade_augmentations'] = True

        self.data_aug_params['cascade_random_binary_transform_p'] = 0.4
        self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 1
        self.data_aug_params['cascade_random_binary_transform_size'] = (1, 8)

        self.data_aug_params['cascade_remove_conn_comp_p'] = 0.2
        self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.15
        self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0

        # we have 2 channels now because the segmentation from the previous stage is stored in 'seg' as well until it
        # is moved to 'data' at the end
        self.data_aug_params['selected_seg_channels'] = [0, 1]
        # needed for converting the segmentation from the previous stage to one hot
        self.data_aug_params['all_segmentation_labels'] = list(range(1, self.num_classes))

    def initialize(self, training=True, force_load_plans=False):
        """
        Load and process the plans, set up data augmentation and (optionally) the training data pipeline, then
        build the network.

        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization

        :param training: if True, data loaders and augmentation generators are created
        :param force_load_plans: if True, the plans file is (re)loaded even if ``self.plans`` is already set
        :return: None
        """
        if force_load_plans or (self.plans is None):
            self.load_plans_file()

        self.process_plans(self.plans)

        self.setup_DA_params()

        self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                  "_stage%d" % self.stage)
        if training:
            # NOTE(review): setup_DA_params() was already called above; this second call looks redundant but is
            # kept because the parent implementation is not visible from here - confirm before removing.
            self.setup_DA_params()

            if self.folder_with_preprocessed_data is not None:
                self.dl_tr, self.dl_val = self.get_basic_generators()

                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
                                                                     self.data_aug_params[
                                                                         'patch_size_for_spatialtransform'],
                                                                     self.data_aug_params)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())))
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())))

        self.initialize_network()
        assert isinstance(self.network, SegmentationNetwork)
        self.was_initialized = True

    def _get_segmentation_export_params(self, segmentation_export_kwargs):
        """
        Resolve the resampling parameters used for exporting segmentations.

        Explicit ``segmentation_export_kwargs`` win; otherwise ``self.plans['segmentation_export_params']`` is used
        if present; otherwise the defaults ``(None, 1, 0)`` apply.

        :return: tuple ``(force_separate_z, interpolation_order, interpolation_order_z)``
        """
        if segmentation_export_kwargs is not None:
            params = segmentation_export_kwargs
        elif 'segmentation_export_params' in self.plans.keys():
            params = self.plans['segmentation_export_params']
        else:
            return None, 1, 0
        return params['force_separate_z'], params['interpolation_order'], params['interpolation_order_z']

    def _copy_gt_niftis(self, gt_nifti_folder):
        """
        Copy all ``.nii.gz`` files from ``self.gt_niftis_folder`` into ``gt_nifti_folder``.

        Each copy is retried up to 10 times (one second apart) on ``OSError``. A file that still cannot be copied
        is reported in the log file; the remaining files are copied regardless (best effort).
        """
        for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
            success = False
            attempts = 0
            while not success and attempts < 10:
                try:
                    shutil.copy(f, gt_nifti_folder)
                    success = True
                except OSError:
                    attempts += 1
                    sleep(1)
            if not success:
                self.print_to_log_file("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
                 step_size: float = 0.5,
                 save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
        """
        Predict all validation cases, export them as nifti, evaluate them and optionally determine postprocessing.

        The network is put into eval mode for the duration of the call; its previous mode is restored and the
        export worker pool is shut down afterwards, also when an exception is raised.

        :param do_mirroring: if True, ``self.data_aug_params['mirror_axes']`` are used for mirroring at test time
        :param use_sliding_window: forwarded to ``predict_preprocessed_data_return_seg_and_softmax``
        :param step_size: forwarded to ``predict_preprocessed_data_return_seg_and_softmax``
        :param save_softmax: if True, a ``<case>.npz`` target for the softmax is passed to the export function
        :param use_gaussian: forwarded to ``predict_preprocessed_data_return_seg_and_softmax``
        :param overwrite: accepted for interface compatibility; not used by this implementation
        :param validation_folder_name: name of the sub folder of ``self.output_folder`` receiving the predictions
        :param debug: forwarded to ``determine_postprocessing``
        :param all_in_gpu: forwarded to ``predict_preprocessed_data_return_seg_and_softmax``
        :param segmentation_export_kwargs: optional dict with the keys ``force_separate_z``,
            ``interpolation_order`` and ``interpolation_order_z``; overrides the values from the plans
        :param run_postprocessing_on_folds: if True, ``determine_postprocessing`` is run after the evaluation
        :return: None
        """
        # check first: on an uninitialized trainer accessing self.network below would fail less helpfully
        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"

        current_mode = self.network.training
        self.network.eval()

        export_pool = None
        try:
            if self.dataset_val is None:
                self.load_dataset()
                self.do_split()

            force_separate_z, interpolation_order, interpolation_order_z = \
                self._get_segmentation_export_params(segmentation_export_kwargs)

            output_folder = join(self.output_folder, validation_folder_name)
            maybe_mkdir_p(output_folder)

            if do_mirroring:
                mirror_axes = self.data_aug_params['mirror_axes']
            else:
                mirror_axes = ()

            pred_gt_tuples = []
            results = []

            # loop invariant, so it is looked up only once
            transpose_backward = self.plans.get('transpose_backward')

            export_pool = Pool(2)

            for k in self.dataset_val.keys():
                properties = load_pickle(self.dataset[k]['properties_file'])
                data = np.load(self.dataset[k]['data_file'])['data']

                # concat segmentation of previous step
                seg_from_prev_stage = np.load(self._prev_stage_seg_file(k))['data'][None]

                print(data.shape)
                data[-1][data[-1] == -1] = 0
                # the last channel of data is dropped and replaced by the one hot encoded seg of the prev stage
                data_for_net = np.concatenate((data[:-1],
                                               to_one_hot(seg_from_prev_stage[0], range(1, self.num_classes))))

                softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(
                    data_for_net,
                    do_mirroring=do_mirroring,
                    mirror_axes=mirror_axes,
                    use_sliding_window=use_sliding_window,
                    step_size=step_size,
                    use_gaussian=use_gaussian,
                    all_in_gpu=all_in_gpu,
                    mixed_precision=self.fp16)[1]

                if transpose_backward is not None:
                    softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in transpose_backward])

                # strip the last 12 characters (modality suffix + extension) from the first data file name
                fname = properties['list_of_data_files'][0].split("/")[-1][:-12]

                if save_softmax:
                    softmax_fname = join(output_folder, fname + ".npz")
                else:
                    softmax_fname = None

                # There is a problem with python process communication that prevents us from communicating objects
                # larger than 2 GB between processes (basically when the length of the pickle string that will be
                # sent is communicated by the multiprocessing.Pipe object then the placeholder (%i I think) does not
                # allow for long enough strings (lol). This could be fixed by changing i to l (for long) but that
                # would require manually patching system python code. We circumvent that problem here by saving
                # softmax_pred to a npy file that will then be read (and finally deleted) by the Process.
                # save_segmentation_nifti_from_softmax can take either filename or np.ndarray and will handle this
                # automatically
                if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be safe
                    # write the temporary file next to the predictions, not into the current working directory
                    tmp_softmax_file = join(output_folder, fname + ".npy")
                    np.save(tmp_softmax_file, softmax_pred)
                    softmax_pred = tmp_softmax_file

                results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
                                                         ((softmax_pred, join(output_folder, fname + ".nii.gz"),
                                                           properties, interpolation_order, self.regions_class_order,
                                                           None, None,
                                                           softmax_fname, None, force_separate_z,
                                                           interpolation_order_z),
                                                          )
                                                         )
                               )

                pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
                                       join(self.gt_niftis_folder, fname + ".nii.gz")])

            # wait for all exports to finish; this also re-raises exceptions that occurred in the workers
            for res in results:
                res.get()

            task = self.dataset_directory.split("/")[-1]
            job_name = self.experiment_name
            _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
                                 json_output_file=join(output_folder, "summary.json"), json_name=job_name,
                                 json_author="Fabian", json_description="",
                                 json_task=task)

            if run_postprocessing_on_folds:
                # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove
                # everything except the largest connected component for each class. To see if this improves results,
                # we do this for all classes and then rerun the evaluation. Those classes for which this resulted in
                # an improved dice score will have this applied during inference as well
                self.print_to_log_file("determining postprocessing")
                determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
                                         final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
                # after this the final predictions for the validation set can be found in
                # validation_folder_name + "_postprocessed". They are always in that folder, even if no
                # postprocessing was applied!

            # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds
            # another postprocessing to be better? In this case we need to consolidate. At the time the consolidation
            # is going to be done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a
            # separate folder to be used later
            gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
            maybe_mkdir_p(gt_nifti_folder)
            self._copy_gt_niftis(gt_nifti_folder)
        finally:
            # always restore the previous train/eval mode and release the worker processes
            self.network.train(current_mode)
            if export_pool is not None:
                export_pool.close()
                export_pool.join()