import shutil
from multiprocessing.pool import Pool
from time import sleep

import matplotlib
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *

from nnunet.evaluation.evaluator import aggregate_scores
from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.paths import network_training_output_dir
from nnunet.postprocessing.connected_components import determine_postprocessing
from nnunet.training.data_augmentation.default_data_augmentation import get_default_augmentation
from nnunet.training.dataloading.dataset_loading import DataLoader3D, unpack_dataset
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.utilities.one_hot_encoding import to_one_hot

matplotlib.use("agg")

class nnUNetTrainerCascadeFullRes(nnUNetTrainer):
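    """
    Trainer for the final, full resolution stage of the nnU-Net cascade. On top of the image channels,
    the network receives the segmentation predicted by the corresponding 3d_lowres model (one one-hot
    channel per foreground class) as additional input.
    """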
    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainer", fp16=False):
        super(nnUNetTrainerCascadeFullRes, self).__init__(plans_file, fold, output_folder, dataset_directory,
                                                          batch_dice, stage, unpack_data, deterministic, fp16)
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, previous_trainer, fp16)

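        # Recover the task name and plans identifier from the output folder. This assumes the usual
        # nnU-Net layout .../TaskXXX_NAME/TRAINER__PLANS_IDENTIFIER/fold_X with '/' path separators.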
        if self.output_folder is not None:
            task = self.output_folder.split("/")[-3]
            plans_identifier = self.output_folder.split("/")[-2].split("__")[-1]

            folder_with_segs_prev_stage = join(network_training_output_dir, "3d_lowres",
                                               task, previous_trainer + "__" + plans_identifier,
                                               "pred_next_stage")
            if not isdir(folder_with_segs_prev_stage):
                raise RuntimeError(
                    "Cannot run the final stage of the cascade. Run the corresponding 3d_lowres training "
                    "first and predict the segmentations for the next stage")
            self.folder_with_segs_from_prev_stage = folder_with_segs_prev_stage
        else:
            self.folder_with_segs_from_prev_stage = None

    def do_split(self):
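        """
        In addition to the regular split, every case is linked to the low-resolution segmentation of the
        previous stage so the data loader can load it alongside the image.
        """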
        super(nnUNetTrainerCascadeFullRes, self).do_split()
        for k in self.dataset:
            self.dataset[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
                                                               k + "_segFromPrevStage.npz")
            assert isfile(self.dataset[k]['seg_from_prev_stage_file']), \
                "seg from prev stage missing: %s" % (self.dataset[k]['seg_from_prev_stage_file'])
        for k in self.dataset_val:
            self.dataset_val[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
                                                                   k + "_segFromPrevStage.npz")
        for k in self.dataset_tr:
            self.dataset_tr[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
                                                                  k + "_segFromPrevStage.npz")

    def get_basic_generators(self):
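        """
        Creates the training and validation data loaders. Only the 3D case is implemented because the
        final stage of the cascade always operates on 3D full resolution data.
        """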
        self.load_dataset()
        self.do_split()
        if self.threeD:
            dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 True, oversample_foreground_percent=self.oversample_foreground_percent)
            dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, True,
                                  oversample_foreground_percent=self.oversample_foreground_percent)
        else:
            raise NotImplementedError("The cascade is only implemented for 3D networks")
        return dl_tr, dl_val

    def process_plans(self, plans):
        super(nnUNetTrainerCascadeFullRes, self).process_plans(plans)
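        # each foreground class of the previous stage's segmentation adds one one-hot input channel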
        self.num_input_channels += (self.num_classes - 1)

    def setup_DA_params(self):
        super().setup_DA_params()
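        # Move the previous stage's segmentation into the data tensor so that all spatial transforms are
        # applied to it as well, and enable the cascade-specific augmentations: random binary
        # (morphological) operations and random removal of connected components corrupt that segmentation
        # during training so the network does not learn to blindly trust the previous stage.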
        self.data_aug_params['move_last_seg_chanel_to_data'] = True
        self.data_aug_params['cascade_do_cascade_augmentations'] = True

        self.data_aug_params['cascade_random_binary_transform_p'] = 0.4
        self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 1
        self.data_aug_params['cascade_random_binary_transform_size'] = (1, 8)

        self.data_aug_params['cascade_remove_conn_comp_p'] = 0.2
        self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.15
        self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0

        self.data_aug_params['selected_seg_channels'] = [0, 1]

        self.data_aug_params['all_segmentation_labels'] = list(range(1, self.num_classes))

    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization
        :param training: whether to set up data loaders and augmentation (True) or only the network (False)
        :param force_load_plans: reload the plans file even if plans are already present
        :return:
        """
        if force_load_plans or (self.plans is None):
            self.load_plans_file()

        self.process_plans(self.plans)
        self.setup_DA_params()

        self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                  "_stage%d" % self.stage)
        if training:
            if self.folder_with_preprocessed_data is not None:
                self.dl_tr, self.dl_val = self.get_basic_generators()

                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print("INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d "
                          "or you will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
                                                                     self.data_aug_params[
                                                                         'patch_size_for_spatialtransform'],
                                                                     self.data_aug_params)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())))
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())))
        self.initialize_network()
        assert isinstance(self.network, SegmentationNetwork)
        self.was_initialized = True

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
                 save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
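        """
        Predicts all validation cases with the current network, exports the segmentations and aggregates
        Dice scores into summary.json. The previous stage's segmentations are concatenated to the input as
        one-hot channels, exactly as during training.
        """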
        current_mode = self.network.training
        self.network.eval()

        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        if segmentation_export_kwargs is None:
            if 'segmentation_export_params' in self.plans.keys():
                force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
                interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
                interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
            else:
                force_separate_z = None
                interpolation_order = 1
                interpolation_order_z = 0
        else:
            force_separate_z = segmentation_export_kwargs['force_separate_z']
            interpolation_order = segmentation_export_kwargs['interpolation_order']
            interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']

        output_folder = join(self.output_folder, validation_folder_name)
        maybe_mkdir_p(output_folder)

        if do_mirroring:
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        pred_gt_tuples = []

        export_pool = Pool(2)
        results = []

        transpose_backward = self.plans.get('transpose_backward')

        for k in self.dataset_val.keys():
            properties = load_pickle(self.dataset[k]['properties_file'])
            data = np.load(self.dataset[k]['data_file'])['data']

            seg_from_prev_stage = np.load(join(self.folder_with_segs_from_prev_stage,
                                               k + "_segFromPrevStage.npz"))['data'][None]

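            # data[-1] is the ground truth segmentation; -1 marks voxels outside the nonzero region
            # (a result of cropping) and is mapped back to background before building the network input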
            print(data.shape)
            data[-1][data[-1] == -1] = 0
            data_for_net = np.concatenate((data[:-1], to_one_hot(seg_from_prev_stage[0],
                                                                 range(1, self.num_classes))))

            softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data_for_net,
                                                                                 do_mirroring=do_mirroring,
                                                                                 mirror_axes=mirror_axes,
                                                                                 use_sliding_window=use_sliding_window,
                                                                                 step_size=step_size,
                                                                                 use_gaussian=use_gaussian,
                                                                                 all_in_gpu=all_in_gpu,
                                                                                 mixed_precision=self.fp16)[1]

            if transpose_backward is not None:
                softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in transpose_backward])

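            # the case identifier is the first data file minus the trailing "_0000.nii.gz" (12 characters)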
            fname = properties['list_of_data_files'][0].split("/")[-1][:-12]

            if save_softmax:
                softmax_fname = join(output_folder, fname + ".npz")
            else:
                softmax_fname = None

"""There is a problem with python process communication that prevents us from communicating obejcts |
|
larger than 2 GB between processes (basically when the length of the pickle string that will be sent is |
|
communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long |
|
enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually |
|
patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will |
|
then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either |
|
filename or np.ndarray and will handle this automatically""" |
|
if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): |
|
np.save(fname + ".npy", softmax_pred) |
|
softmax_pred = fname + ".npy" |
|
|
|
            results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
                                                     ((softmax_pred, join(output_folder, fname + ".nii.gz"),
                                                       properties, interpolation_order, self.regions_class_order,
                                                       None, None, softmax_fname, None, force_separate_z,
                                                       interpolation_order_z),)))

            pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
                                   join(self.gt_niftis_folder, fname + ".nii.gz")])

        _ = [i.get() for i in results]

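        # compute per-case Dice scores against the ground truth and write them to summary.json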
        task = self.dataset_directory.split("/")[-1]
        job_name = self.experiment_name
        _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
                             json_output_file=join(output_folder, "summary.json"), json_name=job_name,
                             json_author="Fabian", json_description="", json_task=task)

        if run_postprocessing_on_folds:
            # determine_postprocessing checks whether removing all but the largest connected component
            # (for the foreground as a whole and/or per class) improves the Dice scores and stores the
            # result so it can be applied at inference time
            self.print_to_log_file("determining postprocessing")
            determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
                                     final_subf_name=validation_folder_name + "_postprocessed", debug=debug)

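        # Copy the ground truth niftis into the trainer's output folder so later evaluation does not
        # depend on the original location. The retry loop guards against transient filesystem errors
        # (this can happen on network drives).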
        gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
        maybe_mkdir_p(gt_nifti_folder)
        for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
            success = False
            attempts = 0
            while not success and attempts < 10:
                try:
                    shutil.copy(f, gt_nifti_folder)
                    success = True
                except OSError:
                    attempts += 1
                    sleep(1)

        self.network.train(current_mode)
        export_pool.close()
        export_pool.join()
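
# Minimal usage sketch (illustration only; `plans_file`, `out_dir` and `dataset_dir` below are
# placeholders that depend on your nnU-Net setup). The corresponding 3d_lowres model must have been
# trained and its "pred_next_stage" segmentations predicted beforehand:
#
#     trainer = nnUNetTrainerCascadeFullRes(plans_file, fold=0, output_folder=out_dir,
#                                           dataset_directory=dataset_dir, stage=1,
#                                           previous_trainer="nnUNetTrainer")
#     trainer.initialize(training=True)
#     trainer.run_training()
#     trainer.validate()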