import shutil |
|
from collections import OrderedDict |
|
from multiprocessing import Pool |
|
from time import sleep |
|
from typing import Tuple, List |
|
|
|
import matplotlib |
|
import numpy as np |
|
import torch |
|
from batchgenerators.utilities.file_and_folder_operations import * |
|
from torch import nn |
|
from torch.optim import lr_scheduler |
|
|
|
import nnunet |
|
from nnunet.configuration import default_num_threads |
|
from nnunet.evaluation.evaluator import aggregate_scores |
|
from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax |
|
from nnunet.network_architecture.generic_UNet import Generic_UNet |
|
from nnunet.network_architecture.initialization import InitWeights_He |
|
from nnunet.network_architecture.neural_network import SegmentationNetwork |
|
from nnunet.postprocessing.connected_components import determine_postprocessing |
|
from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \ |
|
default_2D_augmentation_params, get_default_augmentation, get_patch_size |
|
from nnunet.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset |
|
from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss |
|
from nnunet.training.network_training.network_trainer import NetworkTrainer |
|
from nnunet.utilities.nd_softmax import softmax_helper |
|
from nnunet.utilities.tensor_utilities import sum_tensor |
|
|
|
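# use a non-interactive matplotlib backend so progress plots can be written without a display (e.g. on a cluster)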
matplotlib.use("agg") |
|
|
|
|
|
class nnUNetTrainer(NetworkTrainer): |
|
def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None, |
|
unpack_data=True, deterministic=True, fp16=False): |
|
""" |
|
:param deterministic: |
|
:param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or |
|
None if you wish to load some checkpoint and do inference only |
|
        :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
        :param output_folder: where to store network parameters, the training progress plot and the validation results
        :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
        because the split information is stored in this directory. For running prediction only this input is not
        required and may be set to None.
        Note: the folder with the preprocessed data is not passed explicitly. It is derived in initialize() from
        dataset_directory, the data_identifier of the plans file and the stage. This way differently preprocessed
        versions of the data can coexist and the plans file determines which one is used.
|
:param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the |
|
batch is a pseudo volume? |
|
:param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be |
|
specified for training: |
|
if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0 |
|
:param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but |
|
is considerably slower! Running unpack_data=False with 2d should never be done! |
|
|
|
IMPORTANT: If you inherit from nnUNetTrainer and the init args change then you need to redefine self.init_args |
|
in your init accordingly. Otherwise checkpoints won't load properly! |
|
""" |
|
super(nnUNetTrainer, self).__init__(deterministic, fp16) |
|
self.unpack_data = unpack_data |
|
self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data, |
|
deterministic, fp16) |
|
|
|
self.stage = stage |
|
self.experiment_name = self.__class__.__name__ |
|
self.plans_file = plans_file |
|
self.output_folder = output_folder |
|
self.dataset_directory = dataset_directory |
|
self.output_folder_base = self.output_folder |
|
self.fold = fold |
|
|
|
self.plans = None |
|
|
|
|
|
|
|
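        # ground truth segmentations are expected in <dataset_directory>/gt_segmentations; when only running
        # inference the dataset directory may not exist, in which case this stays None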
if self.dataset_directory is not None and isdir(self.dataset_directory): |
|
self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations") |
|
else: |
|
self.gt_niftis_folder = None |
|
|
|
self.folder_with_preprocessed_data = None |
|
|
|
|
|
|
|
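        # the following attributes are filled in by initialize(): process_plans() sets the plan-derived values,
        # setup_DA_params() the augmentation parameters and get_basic_generators() the data loaders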
self.dl_tr = self.dl_val = None |
|
self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \ |
|
self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \ |
|
self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None |
|
self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None |
|
|
|
self.batch_dice = batch_dice |
|
self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {}) |
|
|
|
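        # accumulators for the online (per-batch) foreground Dice estimate, see run_online_evaluation()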
self.online_eval_foreground_dc = [] |
|
self.online_eval_tp = [] |
|
self.online_eval_fp = [] |
|
self.online_eval_fn = [] |
|
|
|
self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \ |
|
self.min_region_size_per_class = self.min_size_per_class = None |
|
|
|
self.inference_pad_border_mode = "constant" |
|
self.inference_pad_kwargs = {'constant_values': 0} |
|
|
|
self.update_fold(fold) |
|
self.pad_all_sides = None |
|
|
|
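        # optimizer and lr scheduler hyperparameters, used in initialize_optimizer_and_scheduler()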
self.lr_scheduler_eps = 1e-3 |
|
self.lr_scheduler_patience = 30 |
|
self.initial_lr = 3e-4 |
|
self.weight_decay = 3e-5 |
|
|
|
self.oversample_foreground_percent = 0.33 |
|
|
|
self.conv_per_stage = None |
|
self.regions_class_order = None |
|
|
|
def update_fold(self, fold): |
|
""" |
|
used to swap between folds for inference (ensemble of models from cross-validation) |
|
DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS |
|
:param fold: |
|
:return: |
|
""" |
|
if fold is not None: |
|
if isinstance(fold, str): |
|
assert fold == "all", "if self.fold is a string then it must be \'all\'" |
|
if self.output_folder.endswith("%s" % str(self.fold)): |
|
self.output_folder = self.output_folder_base |
|
self.output_folder = join(self.output_folder, "%s" % str(fold)) |
|
else: |
|
if self.output_folder.endswith("fold_%s" % str(self.fold)): |
|
self.output_folder = self.output_folder_base |
|
self.output_folder = join(self.output_folder, "fold_%s" % str(fold)) |
|
self.fold = fold |
|
|
|
def setup_DA_params(self): |
|
if self.threeD: |
|
self.data_aug_params = default_3D_augmentation_params |
|
if self.do_dummy_2D_aug: |
|
self.data_aug_params["dummy_2D"] = True |
|
self.print_to_log_file("Using dummy2d data augmentation") |
|
self.data_aug_params["elastic_deform_alpha"] = \ |
|
default_2D_augmentation_params["elastic_deform_alpha"] |
|
self.data_aug_params["elastic_deform_sigma"] = \ |
|
default_2D_augmentation_params["elastic_deform_sigma"] |
|
self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"] |
|
else: |
|
self.do_dummy_2D_aug = False |
|
if max(self.patch_size) / min(self.patch_size) > 1.5: |
|
default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi) |
|
self.data_aug_params = default_2D_augmentation_params |
|
self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm |
|
|
|
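        # the data loader samples a larger patch (basic_generator_patch_size) so that rotation and scaling during
        # augmentation do not introduce border artifacts in the final patch of size patch_size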
if self.do_dummy_2D_aug: |
|
self.basic_generator_patch_size = get_patch_size(self.patch_size[1:], |
|
self.data_aug_params['rotation_x'], |
|
self.data_aug_params['rotation_y'], |
|
self.data_aug_params['rotation_z'], |
|
self.data_aug_params['scale_range']) |
|
self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size)) |
|
else: |
|
self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'], |
|
self.data_aug_params['rotation_y'], |
|
self.data_aug_params['rotation_z'], |
|
self.data_aug_params['scale_range']) |
|
|
|
self.data_aug_params['selected_seg_channels'] = [0] |
|
self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size |
|
|
|
def initialize(self, training=True, force_load_plans=False): |
|
""" |
|
For prediction of test cases just set training=False, this will prevent loading of training data and |
|
training batchgenerator initialization |
|
:param training: |
|
:return: |
|
""" |
|
|
|
maybe_mkdir_p(self.output_folder) |
|
|
|
if force_load_plans or (self.plans is None): |
|
self.load_plans_file() |
|
|
|
self.process_plans(self.plans) |
|
|
|
self.setup_DA_params() |
|
|
|
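        # data loaders and augmentation pipelines are only needed for training; for inference we only build the
        # network and the optimizer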
if training: |
|
self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + |
|
"_stage%d" % self.stage) |
|
|
|
self.dl_tr, self.dl_val = self.get_basic_generators() |
|
if self.unpack_data: |
|
self.print_to_log_file("unpacking dataset") |
|
unpack_dataset(self.folder_with_preprocessed_data) |
|
self.print_to_log_file("done") |
|
else: |
|
self.print_to_log_file( |
|
"INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you " |
|
"will wait all winter for your model to finish!") |
|
self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val, |
|
self.data_aug_params[ |
|
'patch_size_for_spatialtransform'], |
|
self.data_aug_params) |
|
self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), |
|
also_print_to_console=False) |
|
self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), |
|
also_print_to_console=False) |
|
|
self.initialize_network() |
|
self.initialize_optimizer_and_scheduler() |
|
|
|
self.was_initialized = True |
|
|
|
def initialize_network(self): |
|
""" |
|
This is specific to the U-Net and must be adapted for other network architectures |
|
:return: |
|
""" |
|
|
|
|
|
|
|
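        # one pooling stage per entry in net_num_pool_op_kernel_sizes (taken from the plans file)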
net_numpool = len(self.net_num_pool_op_kernel_sizes) |
|
|
|
if self.threeD: |
|
conv_op = nn.Conv3d |
|
dropout_op = nn.Dropout3d |
|
norm_op = nn.InstanceNorm3d |
|
else: |
|
conv_op = nn.Conv2d |
|
dropout_op = nn.Dropout2d |
|
norm_op = nn.InstanceNorm2d |
|
|
|
norm_op_kwargs = {'eps': 1e-5, 'affine': True} |
|
dropout_op_kwargs = {'p': 0, 'inplace': True} |
|
net_nonlin = nn.LeakyReLU |
|
net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} |
|
self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool, |
|
self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, |
|
dropout_op_kwargs, |
|
net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2), |
|
self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True) |
|
self.network.inference_apply_nonlin = softmax_helper |
|
|
|
if torch.cuda.is_available(): |
|
self.network.cuda() |
|
|
|
def initialize_optimizer_and_scheduler(self): |
|
assert self.network is not None, "self.initialize_network must be called first" |
|
self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, |
|
amsgrad=True) |
|
self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2, |
|
patience=self.lr_scheduler_patience, |
|
verbose=True, threshold=self.lr_scheduler_eps, |
|
threshold_mode="abs") |
|
|
|
def plot_network_architecture(self): |
|
try: |
|
from batchgenerators.utilities.file_and_folder_operations import join |
|
import hiddenlayer as hl |
|
if torch.cuda.is_available(): |
|
g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(), |
|
transforms=None) |
|
else: |
|
g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)), |
|
transforms=None) |
|
g.save(join(self.output_folder, "network_architecture.pdf")) |
|
del g |
|
except Exception as e: |
|
self.print_to_log_file("Unable to plot network architecture:") |
|
self.print_to_log_file(e) |
|
|
|
self.print_to_log_file("\nprinting the network instead:\n") |
|
self.print_to_log_file(self.network) |
|
self.print_to_log_file("\n") |
|
finally: |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
def save_debug_information(self): |
|
|
|
dct = OrderedDict() |
|
for k in self.__dir__(): |
|
if not k.startswith("__"): |
|
if not callable(getattr(self, k)): |
|
dct[k] = str(getattr(self, k)) |
|
del dct['plans'] |
|
del dct['intensity_properties'] |
|
del dct['dataset'] |
|
del dct['dataset_tr'] |
|
del dct['dataset_val'] |
|
save_json(dct, join(self.output_folder, "debug.json")) |
|
|
|
|
|
|
shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl")) |
|
|
|
def run_training(self): |
|
self.save_debug_information() |
|
super(nnUNetTrainer, self).run_training() |
|
|
|
def load_plans_file(self): |
|
""" |
|
This is what actually configures the entire experiment. The plans file is generated by experiment planning |
|
:return: |
|
""" |
|
self.plans = load_pickle(self.plans_file) |
|
|
|
def process_plans(self, plans): |
|
if self.stage is None: |
|
assert len(list(plans['plans_per_stage'].keys())) == 1, \ |
|
"If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \ |
|
"case. Please specify which stage of the cascade must be trained" |
|
self.stage = list(plans['plans_per_stage'].keys())[0] |
|
self.plans = plans |
|
|
|
stage_plans = self.plans['plans_per_stage'][self.stage] |
|
self.batch_size = stage_plans['batch_size'] |
|
self.net_pool_per_axis = stage_plans['num_pool_per_axis'] |
|
self.patch_size = np.array(stage_plans['patch_size']).astype(int) |
|
self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug'] |
|
|
|
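        # backwards compatibility: older plans files only store num_pool_per_axis, so pooling and conv kernel sizes
        # have to be reconstructed from it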
if 'pool_op_kernel_sizes' not in stage_plans.keys(): |
|
assert 'num_pool_per_axis' in stage_plans.keys() |
|
self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...") |
|
self.net_num_pool_op_kernel_sizes = [] |
|
for i in range(max(self.net_pool_per_axis)): |
|
curr = [] |
|
for j in self.net_pool_per_axis: |
|
if (max(self.net_pool_per_axis) - j) <= i: |
|
curr.append(2) |
|
else: |
|
curr.append(1) |
|
self.net_num_pool_op_kernel_sizes.append(curr) |
|
else: |
|
self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes'] |
|
|
|
if 'conv_kernel_sizes' not in stage_plans.keys(): |
|
self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...") |
|
self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1) |
|
else: |
|
self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes'] |
|
|
|
self.pad_all_sides = None |
|
self.intensity_properties = plans['dataset_properties']['intensityproperties'] |
|
self.normalization_schemes = plans['normalization_schemes'] |
|
self.base_num_features = plans['base_num_features'] |
|
self.num_input_channels = plans['num_modalities'] |
|
        self.num_classes = plans['num_classes'] + 1  # background is also a class (label 0)
|
self.classes = plans['all_classes'] |
|
self.use_mask_for_norm = plans['use_mask_for_norm'] |
|
self.only_keep_largest_connected_component = plans['keep_only_largest_region'] |
|
self.min_region_size_per_class = plans['min_region_size_per_class'] |
|
self.min_size_per_class = None |
|
|
|
if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None: |
|
            print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. "
                  "You should rerun preprocessing. We will proceed and assume that both transpose_forward "
                  "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!")
|
plans['transpose_forward'] = [0, 1, 2] |
|
plans['transpose_backward'] = [0, 1, 2] |
|
self.transpose_forward = plans['transpose_forward'] |
|
self.transpose_backward = plans['transpose_backward'] |
|
|
|
if len(self.patch_size) == 2: |
|
self.threeD = False |
|
elif len(self.patch_size) == 3: |
|
self.threeD = True |
|
else: |
|
raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size)) |
|
|
|
if "conv_per_stage" in plans.keys(): |
|
self.conv_per_stage = plans['conv_per_stage'] |
|
else: |
|
self.conv_per_stage = 2 |
|
|
|
def load_dataset(self): |
|
self.dataset = load_dataset(self.folder_with_preprocessed_data) |
|
|
|
def get_basic_generators(self): |
|
self.load_dataset() |
|
self.do_split() |
|
|
|
if self.threeD: |
|
dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, |
|
False, oversample_foreground_percent=self.oversample_foreground_percent, |
|
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') |
|
dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False, |
|
oversample_foreground_percent=self.oversample_foreground_percent, |
|
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') |
|
else: |
|
dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, |
|
oversample_foreground_percent=self.oversample_foreground_percent, |
|
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') |
|
dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, |
|
oversample_foreground_percent=self.oversample_foreground_percent, |
|
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') |
|
return dl_tr, dl_val |
|
|
|
def preprocess_patient(self, input_files): |
|
""" |
|
Used to predict new unseen data. Not used for the preprocessing of the training/test data |
|
:param input_files: |
|
:return: |
|
""" |
|
from nnunet.training.model_restore import recursive_find_python_class |
|
preprocessor_name = self.plans.get('preprocessor_name') |
|
if preprocessor_name is None: |
|
if self.threeD: |
|
preprocessor_name = "GenericPreprocessor" |
|
else: |
|
preprocessor_name = "PreprocessorFor2D" |
|
|
|
print("using preprocessor", preprocessor_name) |
|
preprocessor_class = recursive_find_python_class([join(nnunet.__path__[0], "preprocessing")], |
|
preprocessor_name, |
|
current_module="nnunet.preprocessing") |
|
assert preprocessor_class is not None, "Could not find preprocessor %s in nnunet.preprocessing" % \ |
|
preprocessor_name |
|
preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm, |
|
self.transpose_forward, self.intensity_properties) |
|
|
|
d, s, properties = preprocessor.preprocess_test_case(input_files, |
|
self.plans['plans_per_stage'][self.stage][ |
|
'current_spacing']) |
|
return d, s, properties |
|
|
|
def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None, |
|
softmax_ouput_file: str = None, mixed_precision: bool = True) -> None: |
|
""" |
|
Use this to predict new data |
|
:param input_files: |
|
:param output_file: |
|
:param softmax_ouput_file: |
|
:param mixed_precision: |
|
:return: |
|
""" |
|
print("preprocessing...") |
|
d, s, properties = self.preprocess_patient(input_files) |
|
print("predicting...") |
|
pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"], |
|
mirror_axes=self.data_aug_params['mirror_axes'], |
|
use_sliding_window=True, step_size=0.5, |
|
use_gaussian=True, pad_border_mode='constant', |
|
pad_kwargs={'constant_values': 0}, |
|
verbose=True, all_in_gpu=False, |
|
mixed_precision=mixed_precision)[1] |
|
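        # undo the axis transposition applied during preprocessing (the channel axis stays in front)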
pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward]) |
|
|
|
if 'segmentation_export_params' in self.plans.keys(): |
|
force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] |
|
interpolation_order = self.plans['segmentation_export_params']['interpolation_order'] |
|
interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] |
|
else: |
|
force_separate_z = None |
|
interpolation_order = 1 |
|
interpolation_order_z = 0 |
|
|
|
print("resampling to original spacing and nifti export...") |
|
save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order, |
|
self.regions_class_order, None, None, softmax_ouput_file, |
|
None, force_separate_z=force_separate_z, |
|
interpolation_order_z=interpolation_order_z) |
|
print("done") |
|
|
|
def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, |
|
mirror_axes: Tuple[int] = None, |
|
use_sliding_window: bool = True, step_size: float = 0.5, |
|
use_gaussian: bool = True, pad_border_mode: str = 'constant', |
|
pad_kwargs: dict = None, all_in_gpu: bool = False, |
|
verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]: |
|
""" |
|
:param data: |
|
:param do_mirroring: |
|
:param mirror_axes: |
|
:param use_sliding_window: |
|
:param step_size: |
|
:param use_gaussian: |
|
:param pad_border_mode: |
|
:param pad_kwargs: |
|
:param all_in_gpu: |
|
:param verbose: |
|
:return: |
|
""" |
|
if pad_border_mode == 'constant' and pad_kwargs is None: |
|
pad_kwargs = {'constant_values': 0} |
|
|
|
if do_mirroring and mirror_axes is None: |
|
mirror_axes = self.data_aug_params['mirror_axes'] |
|
|
|
if do_mirroring: |
|
assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \ |
|
"was done without mirroring" |
|
|
|
        assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
|
|
|
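        # remember whether the network was in train mode so it can be restored after prediction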
current_mode = self.network.training |
|
self.network.eval() |
|
ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, |
|
use_sliding_window=use_sliding_window, step_size=step_size, |
|
patch_size=self.patch_size, regions_class_order=self.regions_class_order, |
|
use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, |
|
pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, |
|
mixed_precision=mixed_precision) |
|
self.network.train(current_mode) |
|
return ret |
|
|
|
def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5, |
|
save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, |
|
validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, |
|
segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): |
|
""" |
|
if debug=True then the temporary files generated for postprocessing determination will be kept |
|
""" |
|
|
|
current_mode = self.network.training |
|
self.network.eval() |
|
|
|
assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)" |
|
if self.dataset_val is None: |
|
self.load_dataset() |
|
self.do_split() |
|
|
|
if segmentation_export_kwargs is None: |
|
if 'segmentation_export_params' in self.plans.keys(): |
|
force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] |
|
interpolation_order = self.plans['segmentation_export_params']['interpolation_order'] |
|
interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] |
|
else: |
|
force_separate_z = None |
|
interpolation_order = 1 |
|
interpolation_order_z = 0 |
|
else: |
|
force_separate_z = segmentation_export_kwargs['force_separate_z'] |
|
interpolation_order = segmentation_export_kwargs['interpolation_order'] |
|
interpolation_order_z = segmentation_export_kwargs['interpolation_order_z'] |
|
|
|
|
|
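        # raw (not postprocessed) predictions of the validation cases are written here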
output_folder = join(self.output_folder, validation_folder_name) |
|
maybe_mkdir_p(output_folder) |
|
|
|
my_input_args = {'do_mirroring': do_mirroring, |
|
'use_sliding_window': use_sliding_window, |
|
'step_size': step_size, |
|
'save_softmax': save_softmax, |
|
'use_gaussian': use_gaussian, |
|
'overwrite': overwrite, |
|
'validation_folder_name': validation_folder_name, |
|
'debug': debug, |
|
'all_in_gpu': all_in_gpu, |
|
'segmentation_export_kwargs': segmentation_export_kwargs, |
|
} |
|
save_json(my_input_args, join(output_folder, "validation_args.json")) |
|
|
|
if do_mirroring: |
|
if not self.data_aug_params['do_mirror']: |
|
raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled") |
|
mirror_axes = self.data_aug_params['mirror_axes'] |
|
else: |
|
mirror_axes = () |
|
|
|
pred_gt_tuples = [] |
|
|
|
export_pool = Pool(default_num_threads) |
|
results = [] |
|
|
|
for q, k in enumerate(self.dataset_val.keys()): |
|
print("{}/{}".format(q+1,len(self.dataset_val))) |
|
properties = load_pickle(self.dataset[k]['properties_file']) |
|
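            # case identifier: file name of the first modality without the "_0000.nii.gz" suffix (12 characters)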
fname = properties['list_of_data_files'][0].split("/")[-1][:-12] |
|
if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \ |
|
(save_softmax and not isfile(join(output_folder, fname + ".npz"))): |
|
data = np.load(self.dataset[k]['data_file'])['data'] |
|
|
|
print(k, data.shape) |
|
data[-1][data[-1] == -1] = 0 |
|
|
|
                softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
|
do_mirroring=do_mirroring, |
|
mirror_axes=mirror_axes, |
|
use_sliding_window=use_sliding_window, |
|
step_size=step_size, |
|
use_gaussian=use_gaussian, |
|
all_in_gpu=all_in_gpu, |
|
mixed_precision=self.fp16)[1] |
|
|
|
softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward]) |
|
|
|
if save_softmax: |
|
softmax_fname = join(output_folder, fname + ".npz") |
|
else: |
|
softmax_fname = None |
|
|
|
"""There is a problem with python process communication that prevents us from communicating obejcts |
|
larger than 2 GB between processes (basically when the length of the pickle string that will be sent is |
|
communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long |
|
enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually |
|
patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will |
|
then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either |
|
filename or np.ndarray and will handle this automatically""" |
|
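                # softmax is float32 (4 bytes per value); stay safely below the ~2 GB limit described above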
if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): |
|
np.save(join(output_folder, fname + ".npy"), softmax_pred) |
|
softmax_pred = join(output_folder, fname + ".npy") |
|
|
|
""" |
|
resu = save_segmentation_nifti_from_softmax(softmax_pred, join(output_folder, fname + ".nii.gz"), |
|
properties, interpolation_order, self.regions_class_order, |
|
None, None, |
|
softmax_fname, None, force_separate_z, |
|
interpolation_order_z) |
|
results.append(resu) |
|
|
|
# this eats RAM |
|
""" |
|
results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax, |
|
((softmax_pred, join(output_folder, fname + ".nii.gz"), |
|
properties, interpolation_order, self.regions_class_order, |
|
None, None, |
|
softmax_fname, None, force_separate_z, |
|
interpolation_order_z), |
|
) |
|
) |
|
) |
|
|
|
|
|
|
|
pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"), |
|
join(self.gt_niftis_folder, fname + ".nii.gz")]) |
|
|
|
        _ = [i.get() for i in results]
        self.print_to_log_file("finished prediction")

        # all export jobs are finished, shut down the worker pool
        export_pool.close()
        export_pool.join()
|
|
|
|
|
self.print_to_log_file("evaluation of raw predictions") |
|
task = self.dataset_directory.split("/")[-1] |
|
job_name = self.experiment_name |
|
        _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
|
json_output_file=join(output_folder, "summary.json"), |
|
json_name=job_name + " val tiled %s" % (str(use_sliding_window)), |
|
json_author="Fabian", |
|
json_task=task, num_threads=default_num_threads) |
|
|
|
if run_postprocessing_on_folds: |
|
|
|
|
|
|
|
|
|
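            # determine_postprocessing checks per class whether keeping only the largest connected component
            # improves the Dice and stores the result so it can be applied at inference time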
self.print_to_log_file("determining postprocessing") |
|
determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name, |
|
final_subf_name=validation_folder_name + "_postprocessed", debug=debug) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
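        # copy the ground truth niftis next to the model output so that postprocessing can later be consolidated
        # across folds without access to the original gt_niftis_folder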
gt_nifti_folder = join(self.output_folder_base, "gt_niftis") |
|
maybe_mkdir_p(gt_nifti_folder) |
|
for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"): |
|
success = False |
|
attempts = 0 |
|
e = None |
|
while not success and attempts < 10: |
|
try: |
|
shutil.copy(f, gt_nifti_folder) |
|
success = True |
|
                except OSError as exc:
                    # Python 3 unbinds the "as" target after the except block, so keep a reference for the
                    # potential re-raise below
                    e = exc
                    attempts += 1
|
sleep(1) |
|
if not success: |
|
print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder)) |
|
if e is not None: |
|
raise e |
|
|
|
self.network.train(current_mode) |
|
|
|
def run_online_evaluation(self, output, target): |
|
with torch.no_grad(): |
|
num_classes = output.shape[1] |
|
output_softmax = softmax_helper(output) |
|
output_seg = output_softmax.argmax(1) |
|
target = target[:, 0] |
|
axes = tuple(range(1, len(target.shape))) |
|
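            # hard (argmax-based) per-class true positives, false positives and false negatives, summed over the
            # spatial axes; the background class (0) is excluded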
tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) |
|
fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) |
|
fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) |
|
for c in range(1, num_classes): |
|
tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes) |
|
fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes) |
|
fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes) |
|
|
|
tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy() |
|
fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy() |
|
fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy() |
|
|
|
self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8))) |
|
self.online_eval_tp.append(list(tp_hard)) |
|
self.online_eval_fp.append(list(fp_hard)) |
|
self.online_eval_fn.append(list(fn_hard)) |
|
|
|
def finish_online_evaluation(self): |
|
self.online_eval_tp = np.sum(self.online_eval_tp, 0) |
|
self.online_eval_fp = np.sum(self.online_eval_fp, 0) |
|
self.online_eval_fn = np.sum(self.online_eval_fn, 0) |
|
|
|
|
|
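        # "global" Dice: computed from TP/FP/FN accumulated over the whole epoch rather than averaged per case,
        # which is why it is only an estimate of the per-class Dice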
global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in |
|
zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)] |
|
if not np.isnan(i)] |
|
self.all_val_eval_metrics.append(np.mean(global_dc_per_class)) |
|
|
|
self.print_to_log_file("Average global foreground Dice:", [np.round(i, 4) for i in global_dc_per_class]) |
|
self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. This is not " |
|
"exact.)") |
|
|
|
self.online_eval_foreground_dc = [] |
|
self.online_eval_tp = [] |
|
self.online_eval_fp = [] |
|
self.online_eval_fn = [] |
|
|
|
def save_checkpoint(self, fname, save_optimizer=True): |
|
super(nnUNetTrainer, self).save_checkpoint(fname, save_optimizer) |
|
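        # store everything needed to restore this trainer alongside the checkpoint: init args, class and plans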
info = OrderedDict() |
|
info['init'] = self.init_args |
|
info['name'] = self.__class__.__name__ |
|
info['class'] = str(self.__class__) |
|
info['plans'] = self.plans |
|
|
|
write_pickle(info, fname + ".pkl") |
|
|