# nnUNet_calvingfront_detection/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA3.py
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np | |
import torch | |
from batchgenerators.utilities.file_and_folder_operations import join | |
from nnunet.network_architecture.generic_UNet import Generic_UNet | |
from nnunet.network_architecture.initialization import InitWeights_He | |
from nnunet.network_architecture.neural_network import SegmentationNetwork | |
from nnunet.training.data_augmentation.data_augmentation_insaneDA2 import get_insaneDA_augmentation2 | |
from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \ | |
default_2D_augmentation_params, get_patch_size | |
from nnunet.training.dataloading.dataset_loading import unpack_dataset | |
from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2 | |
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2, maybe_mkdir_p | |
from nnunet.utilities.nd_softmax import softmax_helper | |
from torch import nn | |
class nnUNetTrainerV2_DA3(nnUNetTrainerV2):
    """nnUNetTrainerV2 variant with a more aggressive data augmentation setup.

    Overrides the augmentation parameters:
    - wider rotation, scale and gamma ranges;
    - elastic deformation and per-axis scaling;
    - additive brightness.

    It builds the training/validation generators with ``get_insaneDA_augmentation2``.
    """

    def setup_DA_params(self):
        """Populate ``self.data_aug_params`` and ``self.basic_generator_patch_size``.

        Runs the parent's setup first. Then replaces the augmentation dictionary with a
        private copy of the module-level 3D or 2D defaults and overrides selected entries.
        Working on a copy means ``default_3D_augmentation_params`` and
        ``default_2D_augmentation_params`` are not modified by this method. Other trainers
        created in the same process therefore do not inherit these overrides.
        """
        super().setup_DA_params()
        # One scale per deep-supervision output. The first is full resolution; then come the
        # reciprocals of the cumulative pooling factors, with the coarsest one dropped.
        self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
            np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]

        rotation_30_deg = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
        if self.threeD:
            self.data_aug_params = dict(default_3D_augmentation_params)
            self.data_aug_params['rotation_x'] = rotation_30_deg
            self.data_aug_params['rotation_y'] = rotation_30_deg
            self.data_aug_params['rotation_z'] = rotation_30_deg
            if self.do_dummy_2D_aug:
                # 3D data augmented slice-wise: borrow the 2D elastic/rotation settings.
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            self.data_aug_params = dict(default_2D_augmentation_params)
            # Strongly elongated 2D patches: allow full +/-180 degree in-plane rotation.
            # Only the copy is modified, never the module-level default.
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                self.data_aug_params['rotation_x'] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        if self.do_dummy_2D_aug:
            # The enlarged size is computed for the trailing axes only; the first axis keeps
            # its original extent. Presumably the enlargement leaves room for rotation and
            # scaling; confirm against get_patch_size.
            in_plane_size = get_patch_size(self.patch_size[1:],
                                           self.data_aug_params['rotation_x'],
                                           self.data_aug_params['rotation_y'],
                                           self.data_aug_params['rotation_z'],
                                           self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(in_plane_size))
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])

        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size

        self.data_aug_params["p_rot"] = 0.3

        # NOTE(review): basic_generator_patch_size above was computed with the scale_range in
        # effect *before* this override, so the wider range set here is not reflected in it.
        # The order is kept to preserve the loaded patch size; confirm whether this is intended.
        self.data_aug_params["scale_range"] = (0.65, 1.6)
        self.data_aug_params["p_scale"] = 0.3
        self.data_aug_params["independent_scale_factor_for_each_axis"] = True
        self.data_aug_params["p_independent_scale_per_axis"] = 0.3

        self.data_aug_params["do_elastic"] = True
        self.data_aug_params["p_eldef"] = 0.3
        self.data_aug_params["eldef_deformation_scale"] = (0, 0.25)

        self.data_aug_params["do_additive_brightness"] = True
        self.data_aug_params["additive_brightness_mu"] = 0
        self.data_aug_params["additive_brightness_sigma"] = 0.2
        self.data_aug_params["additive_brightness_p_per_sample"] = 0.3
        self.data_aug_params["additive_brightness_p_per_channel"] = 1

        self.data_aug_params['gamma_range'] = (0.5, 1.6)

        self.data_aug_params['num_cached_per_thread'] = 4

    def initialize(self, training=True, force_load_plans=False):
        """Set up plans, augmentation, the deep-supervision loss, data pipelines, network and optimizer.

        :param training: if True, also create the data loaders and the augmentation
            generators, and log the training/validation keys.
        :param force_load_plans: reload the plans file even when ``self.plans`` is already set.

        If the trainer was already initialized, only a log message is written.
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # we need to know the number of outputs of the network
            net_numpool = len(self.net_num_pool_op_kernel_sizes)

            # each output gets a weight that halves with every resolution step, so higher
            # resolution outputs get more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

            # the lowest-resolution output (index net_numpool - 1) is ignored; the first one
            # is always kept. Normalize the remaining weights so that they sum to 1.
            mask = np.array([True] + [i < net_numpool - 1 for i in range(1, net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.ds_loss_weights = weights
            # now wrap the loss
            self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
            ################# END ###################

            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_insaneDA_augmentation2(
                    self.dl_tr, self.dl_val,
                    self.data_aug_params[
                        'patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory
                )
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True
class nnUNetTrainerV2_DA3_BN(nnUNetTrainerV2_DA3):
    """nnUNetTrainerV2_DA3 with batch normalization in the network.

    Inherits the augmentation and initialization logic of nnUNetTrainerV2_DA3. It only
    replaces the network construction.
    """

    def initialize_network(self):
        """Build a ``Generic_UNet`` that uses BatchNorm layers and store it in ``self.network``.

        Convolution, dropout and normalization layer types follow ``self.threeD`` (3D or 2D
        variants). Dropout is configured with ``p=0``. The network is moved to the GPU when
        CUDA is available, and ``softmax_helper`` is set as its inference non-linearity.
        """
        if self.threeD:
            conv_op, dropout_op, norm_op = nn.Conv3d, nn.Dropout3d, nn.BatchNorm3d
        else:
            conv_op, dropout_op, norm_op = nn.Conv2d, nn.Dropout2d, nn.BatchNorm2d

        # Positional arguments are passed in the same order as before; see Generic_UNet's
        # signature for the meaning of each one.
        self.network = Generic_UNet(
            self.num_input_channels,
            self.base_num_features,
            self.num_classes,
            len(self.net_num_pool_op_kernel_sizes),
            self.conv_per_stage,
            2,
            conv_op,
            norm_op,
            {'eps': 1e-5, 'affine': True},
            dropout_op,
            {'p': 0, 'inplace': True},  # dropout probability 0
            nn.LeakyReLU,
            {'negative_slope': 1e-2, 'inplace': True},
            True,
            False,
            lambda x: x,
            InitWeights_He(1e-2),
            self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes,
            False,
            True,
            True,
        )

        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper