from collections import OrderedDict |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import torch |
|
from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation |
|
from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2 |
|
from nnunet.utilities.to_torch import maybe_to_torch, to_cuda |
|
from nnunet.network_architecture.generic_UNet_MTLearly import Generic_UNet_MTLearly |
|
from nnunet.network_architecture.initialization import InitWeights_He |
|
from nnunet.network_architecture.neural_network import SegmentationNetwork |
|
from nnunet.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \ |
|
get_patch_size, default_3D_augmentation_params |
|
from nnunet.training.dataloading.dataset_loading import unpack_dataset |
|
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer |
|
from nnunet.utilities.nd_softmax import softmax_helper |
|
from sklearn.model_selection import KFold |
|
from torch import nn |
|
from torch.cuda.amp import autocast |
|
from nnunet.training.learning_rate.poly_lr import poly_lr |
|
from batchgenerators.utilities.file_and_folder_operations import * |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class nnUNetTrainerMTLearly(nnUNetTrainer): |
|
|
|
|
|
def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None, |
|
unpack_data=True, deterministic=True, fp16=False): |
|
super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data, |
|
deterministic, fp16) |
|
|
|
self.max_num_epochs = 400 |
|
self.num_batches_per_epoch = 250 |
|
self.num_val_batches_per_epoch = 50 |
|
|
|
self.initial_lr = 1e-2 |
|
self.deep_supervision_scales = None |
|
self.ds_loss_weights = None |
|
|
|
self.pin_memory = True |
|
|
|
def initialize(self, training=True, force_load_plans=False, disable_deepsupervision=False): |
|
""" |
|
- replaced get_default_augmentation with get_moreDA_augmentation |
|
        - ensure this code only runs once
        - wrap the loss function for deep supervision

        :param training:
        :param force_load_plans:
        :param disable_deepsupervision: currently unused in this trainer; kept for interface compatibility
        :return:
|
""" |
|
if not self.was_initialized: |
|
maybe_mkdir_p(self.output_folder) |
|
|
|
if force_load_plans or (self.plans is None): |
|
self.load_plans_file() |
|
|
|
self.process_plans(self.plans) |
|
|
|
|
|
self.setup_DA_params() |
|
|
|
|
|
|
|
net_numpool = len(self.net_num_pool_op_kernel_sizes) |
|
|
|
|
|
|
|
weights = np.array([1 / (2 ** i) for i in range(net_numpool)]) |
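            # a minimal example, assuming 5 resolution levels: the raw weights are [1, 1/2, 1/4, 1/8, 1/16];
            # the lowest resolution is then masked out below and the remaining weights normalized to sum to 1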
|
|
|
|
|
mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)]) |
|
weights[~mask] = 0 |
|
weights = weights / weights.sum() |
|
self.ds_loss_weights = weights |
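            # two deep-supervision loss wrappers, one per task: the zone loss uses the decaying weights above,
            # while passing None to MultipleOutputLoss2 weights all deep supervision outputs equally for the front loss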
|
|
|
self.loss_zone = MultipleOutputLoss2(self.loss, self.ds_loss_weights) |
|
|
|
self.loss_front = MultipleOutputLoss2(self.loss, None) |
|
|
|
|
|
self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + |
|
"_stage%d" % self.stage) |
|
if training: |
|
|
|
self.dl_tr, self.dl_val = self.get_basic_generators() |
|
if self.unpack_data: |
|
print("unpacking dataset") |
|
unpack_dataset(self.folder_with_preprocessed_data) |
|
print("done") |
|
else: |
|
print( |
|
"INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you " |
|
"will wait all winter for your model to finish!") |
|
|
|
self.tr_gen, self.val_gen = get_moreDA_augmentation( |
|
self.dl_tr, self.dl_val, |
|
self.data_aug_params[ |
|
'patch_size_for_spatialtransform'], |
|
self.data_aug_params, |
|
deep_supervision_scales=self.deep_supervision_scales, |
|
pin_memory=self.pin_memory, |
|
use_nondetMultiThreadedAugmenter=False |
|
) |
|
self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), |
|
also_print_to_console=False) |
|
self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), |
|
also_print_to_console=False) |
|
else: |
|
pass |
|
|
|
self.initialize_network() |
|
self.initialize_optimizer_and_scheduler() |
|
|
|
assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel)) |
|
else: |
|
self.print_to_log_file('self.was_initialized is True, not running self.initialize again') |
|
self.was_initialized = True |
|
|
|
def initialize_network(self): |
|
""" |
|
- momentum 0.99 |
|
- SGD instead of Adam |
|
- self.lr_scheduler = None because we do poly_lr |
|
- deep supervision = True |
|
- i am sure I forgot something here |
|
|
|
Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though |
|
:return: |
|
""" |
|
if self.threeD: |
|
conv_op = nn.Conv3d |
|
dropout_op = nn.Dropout3d |
|
norm_op = nn.InstanceNorm3d |
|
|
|
else: |
|
conv_op = nn.Conv2d |
|
dropout_op = nn.Dropout2d |
|
norm_op = nn.InstanceNorm2d |
|
|
|
norm_op_kwargs = {'eps': 1e-5, 'affine': True} |
|
dropout_op_kwargs = {'p': 0, 'inplace': True} |
|
net_nonlin = nn.LeakyReLU |
|
net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} |
|
|
|
self.network = Generic_UNet_MTLearly(self.num_input_channels, self.base_num_features, self.num_classes, |
|
len(self.net_num_pool_op_kernel_sizes), |
|
self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, |
|
dropout_op_kwargs, |
|
net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2), |
|
self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True) |
|
|
|
print("number of parameters") |
|
print(sum(param.numel() for param in self.network.parameters())) |
|
|
|
if torch.cuda.is_available(): |
|
self.network.cuda() |
|
self.network.inference_apply_nonlin = softmax_helper |
|
|
|
def initialize_optimizer_and_scheduler(self): |
|
assert self.network is not None, "self.initialize_network must be called first" |
|
self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, |
|
momentum=0.99, nesterov=True) |
|
self.lr_scheduler = None |
|
|
|
def run_online_evaluation(self, output, target): |
|
""" |
|
        due to deep supervision, output and target are lists of tensors. We only evaluate the full-resolution
        output, since that is what we are ultimately interested in; the lower-resolution outputs are ignored.
|
:param output: |
|
:param target: |
|
:return: |
|
""" |
|
target = target[0] |
|
output = output[0] |
|
return super().run_online_evaluation(output, target) |
|
|
|
def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, |
|
step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, |
|
validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, |
|
segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): |
|
""" |
|
We need to wrap this because we need to enforce self.network.do_ds = False for prediction |
|
""" |
|
ds = self.network.do_ds |
|
self.network.do_ds = False |
|
ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size, |
|
save_softmax=save_softmax, use_gaussian=use_gaussian, |
|
overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug, |
|
all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs, |
|
run_postprocessing_on_folds=run_postprocessing_on_folds) |
|
|
|
self.network.do_ds = ds |
|
return ret |
|
|
|
def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, |
|
mirror_axes: Tuple[int] = None, |
|
use_sliding_window: bool = True, step_size: float = 0.5, |
|
use_gaussian: bool = True, pad_border_mode: str = 'constant', |
|
pad_kwargs: dict = None, all_in_gpu: bool = False, |
|
verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]: |
|
""" |
|
We need to wrap this because we need to enforce self.network.do_ds = False for prediction |
|
""" |
|
ds = self.network.do_ds |
|
self.network.do_ds = False |
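        # note that only the first channel of the preprocessed data is forwarded to prediction (data[:1])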
|
ret = super().predict_preprocessed_data_return_seg_and_softmax(data[:1], |
|
do_mirroring=do_mirroring, |
|
mirror_axes=mirror_axes, |
|
use_sliding_window=use_sliding_window, |
|
step_size=step_size, use_gaussian=use_gaussian, |
|
pad_border_mode=pad_border_mode, |
|
pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, |
|
verbose=verbose, |
|
mixed_precision=mixed_precision) |
|
self.network.do_ds = ds |
|
return ret |
|
|
|
def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False): |
|
""" |
|
gradient clipping improves training stability |
|
|
|
:param data_generator: |
|
:param do_backprop: |
|
:param run_online_evaluation: |
|
:return: |
|
""" |
|
data_dict = next(data_generator) |
|
data = data_dict['data'] |
|
target = data_dict['target'] |
|
|
|
data = maybe_to_torch(data) |
|
target = maybe_to_torch(target) |
|
|
|
if torch.cuda.is_available(): |
|
data = to_cuda(data) |
|
target = to_cuda(target) |
|
|
|
self.optimizer.zero_grad() |
|
|
|
if self.fp16: |
|
with autocast(): |
|
|
|
output = self.network(data) |
|
del data |
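                # split targets and logits into the two tasks: target channel 0 with output channels 0:2
                # for the 2-class front task, target channel 1 with output channels 2:6 for the 4-class zone task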
|
|
|
target0 = [target[i][:, [0]].to(torch.float16) for i in range(len(target))] |
|
output0 = [output[i][:, 0:2] for i in range(len(output))] |
|
|
|
target1 = [target[i][:, [1]].to(torch.float16) for i in range(len(target))] |
|
output1 = [output[i][:, 2:6] for i in range(len(output))] |
|
|
|
l_front = self.loss_front(output0, target0) |
|
|
|
l_zone = self.loss_zone(output1, target1) |
|
l = 0.5*l_front + 0.5*l_zone |
|
|
|
if do_backprop: |
|
self.amp_grad_scaler.scale(l).backward() |
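                # gradients are unscaled before clipping so the threshold of 12 applies to the true gradient norm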
|
self.amp_grad_scaler.unscale_(self.optimizer) |
|
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) |
|
self.amp_grad_scaler.step(self.optimizer) |
|
self.amp_grad_scaler.update() |
|
        else:
            output = self.network(data)
            del data

            # split targets and logits into the two tasks, mirroring the fp16 branch above; without this,
            # the online evaluation and the del statements below would reference undefined variables
            target0 = [target[i][:, [0]] for i in range(len(target))]
            output0 = [output[i][:, 0:2] for i in range(len(output))]

            target1 = [target[i][:, [1]] for i in range(len(target))]
            output1 = [output[i][:, 2:6] for i in range(len(output))]

            l_front = self.loss_front(output0, target0)
            l_zone = self.loss_zone(output1, target1)
            l = 0.5 * l_front + 0.5 * l_zone

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()
|
|
|
if run_online_evaluation: |
|
self.run_online_evaluation(output0, target0) |
|
self.run_online_evaluation(output1, target1) |
|
|
|
del target |
|
del target0 |
|
del target1 |
|
|
|
del output |
|
del output0 |
|
del output1 |
|
|
|
return l.detach().cpu().numpy() |
|
|
|
def do_split(self): |
|
""" |
|
The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, |
|
so always the same) and save it as splits_final.pkl file in the preprocessed data directory. |
|
Sometimes you may want to create your own split for various reasons. For this you will need to create your own |
|
splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in |
|
it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) |
|
and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to |
|
use a random 80:20 data split. |
|
:return: |
|
""" |
|
if self.fold == "all": |
|
|
|
tr_keys = val_keys = list(self.dataset.keys()) |
|
else: |
|
splits_file = join(self.dataset_directory, "splits_final.pkl") |
|
|
|
|
|
if not isfile(splits_file): |
|
self.print_to_log_file("Creating new 5-fold cross-validation split...") |
|
splits = [] |
|
all_keys_sorted = np.sort(list(self.dataset.keys())) |
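                # the fixed random_state makes the 5-fold split reproducible across runs and trainer variants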
|
kfold = KFold(n_splits=5, shuffle=True, random_state=12345) |
|
for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): |
|
train_keys = np.array(all_keys_sorted)[train_idx] |
|
test_keys = np.array(all_keys_sorted)[test_idx] |
|
splits.append(OrderedDict()) |
|
splits[-1]['train'] = train_keys |
|
splits[-1]['val'] = test_keys |
|
save_pickle(splits, splits_file) |
|
|
|
else: |
|
self.print_to_log_file("Using splits from existing split file:", splits_file) |
|
splits = load_pickle(splits_file) |
|
self.print_to_log_file("The split file contains %d splits." % len(splits)) |
|
|
|
self.print_to_log_file("Desired fold for training: %d" % self.fold) |
|
if self.fold < len(splits): |
|
tr_keys = splits[self.fold]['train'] |
|
val_keys = splits[self.fold]['val'] |
|
self.print_to_log_file("This split has %d training and %d validation cases." |
|
% (len(tr_keys), len(val_keys))) |
|
else: |
|
self.print_to_log_file("INFO: You requested fold %d for training but splits " |
|
"contain only %d folds. I am now creating a " |
|
"random (but seeded) 80:20 split!" % (self.fold, len(splits))) |
|
|
|
rnd = np.random.RandomState(seed=12345 + self.fold) |
|
keys = np.sort(list(self.dataset.keys())) |
|
idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) |
|
idx_val = [i for i in range(len(keys)) if i not in idx_tr] |
|
tr_keys = [keys[i] for i in idx_tr] |
|
val_keys = [keys[i] for i in idx_val] |
|
self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." |
|
% (len(tr_keys), len(val_keys))) |
|
|
|
tr_keys.sort() |
|
val_keys.sort() |
|
self.dataset_tr = OrderedDict() |
|
for i in tr_keys: |
|
self.dataset_tr[i] = self.dataset[i] |
|
self.dataset_val = OrderedDict() |
|
for i in val_keys: |
|
self.dataset_val[i] = self.dataset[i] |
|
|
|
def setup_DA_params(self): |
|
""" |
|
        - we increase the rotation angle from [-15, 15] to [-30, 30]
|
- scale range is now (0.7, 1.4), was (0.85, 1.25) |
|
- we don't do elastic deformation anymore |
|
|
|
:return: |
|
""" |
|
|
|
self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod( |
|
np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1] |
|
|
|
if self.threeD: |
|
self.data_aug_params = default_3D_augmentation_params |
|
self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) |
|
self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) |
|
self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) |
|
if self.do_dummy_2D_aug: |
|
self.data_aug_params["dummy_2D"] = True |
|
self.print_to_log_file("Using dummy2d data augmentation") |
|
self.data_aug_params["elastic_deform_alpha"] = \ |
|
default_2D_augmentation_params["elastic_deform_alpha"] |
|
self.data_aug_params["elastic_deform_sigma"] = \ |
|
default_2D_augmentation_params["elastic_deform_sigma"] |
|
self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"] |
|
else: |
|
self.do_dummy_2D_aug = False |
|
if max(self.patch_size) / min(self.patch_size) > 1.5: |
|
default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi) |
|
self.data_aug_params = default_2D_augmentation_params |
|
self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm |
|
|
|
if self.do_dummy_2D_aug: |
|
self.basic_generator_patch_size = get_patch_size(self.patch_size[1:], |
|
self.data_aug_params['rotation_x'], |
|
self.data_aug_params['rotation_y'], |
|
self.data_aug_params['rotation_z'], |
|
self.data_aug_params['scale_range']) |
|
self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size)) |
|
else: |
|
self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'], |
|
self.data_aug_params['rotation_y'], |
|
self.data_aug_params['rotation_z'], |
|
self.data_aug_params['scale_range']) |
|
|
|
self.data_aug_params["scale_range"] = (0.7, 1.4) |
|
self.data_aug_params["do_elastic"] = False |
|
self.data_aug_params['selected_seg_channels'] = [0, 1] |
|
self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size |
|
|
|
self.data_aug_params["num_cached_per_thread"] = 2 |
|
|
|
def maybe_update_lr(self, epoch=None): |
|
""" |
|
if epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1 |
|
|
|
(maybe_update_lr is called in on_epoch_end which is called before epoch is incremented. |
|
        Therefore we need to do +1 here)
|
|
|
:param epoch: |
|
:return: |
|
""" |
|
if epoch is None: |
|
ep = self.epoch + 1 |
|
else: |
|
ep = epoch |
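        # polynomial decay: lr = initial_lr * (1 - ep / max_num_epochs) ** 0.9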
|
self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9) |
|
self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6)) |
|
|
|
def on_epoch_end(self): |
|
""" |
|
        overwrite patience-based early stopping. Always run to self.max_num_epochs (400 here)
|
:return: |
|
""" |
|
super().on_epoch_end() |
|
continue_training = self.epoch < self.max_num_epochs |
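        # safeguard inherited from nnUNetTrainerV2: if the mean foreground Dice is still 0 at epoch 100,
        # reduce the momentum to 0.95 and reinitialize the network weights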
|
|
|
|
|
|
|
if self.epoch == 100: |
|
if self.all_val_eval_metrics[-1] == 0: |
|
self.optimizer.param_groups[0]["momentum"] = 0.95 |
|
self.network.apply(InitWeights_He(1e-2)) |
|
self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too " |
|
"high momentum. High momentum (0.99) is good for datasets where it works, but " |
|
"sometimes causes issues such as this one. Momentum has now been reduced to " |
|
"0.95 and network weights have been reinitialized") |
|
return continue_training |
|
|
|
def run_training(self): |
|
""" |
|
if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first |
|
continued epoch with self.initial_lr |
|
|
|
we also need to make sure deep supervision in the network is enabled for training, thus the wrapper |
|
:return: |
|
""" |
|
self.maybe_update_lr(self.epoch) |
|
|
|
ds = self.network.do_ds |
|
self.network.do_ds = True |
|
ret = super().run_training() |
|
self.network.do_ds = ds |
|
return ret |
|
|