# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunet.network_architecture.generic_UNet import Generic_UNet
from nnunet.network_architecture.initialization import InitWeights_He
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.training.data_augmentation.data_augmentation_insaneDA2 import get_insaneDA_augmentation2
from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
    default_2D_augmentation_params, get_patch_size
from nnunet.training.dataloading.dataset_loading import unpack_dataset
from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2, maybe_mkdir_p
from nnunet.utilities.nd_softmax import softmax_helper
from torch import nn


class nnUNetTrainerV2_DA3(nnUNetTrainerV2):
    """nnUNetTrainerV2 variant with its own augmentation parameters.

    ``setup_DA_params`` overrides a set of entries in ``self.data_aug_params`` and ``initialize``
    builds the training/validation generators with ``get_insaneDA_augmentation2``.
    """

    def setup_DA_params(self):
        """Populate ``self.data_aug_params``, ``self.deep_supervision_scales`` and
        ``self.basic_generator_patch_size`` for this trainer.

        The parent implementation runs first; afterwards the parameter dict is rebuilt from the
        module-level defaults (3D or 2D depending on ``self.threeD``) and selectively overridden.
        """
        super().setup_DA_params()

        # One scale per output: full resolution first, then the cumulative inverse pooling factors
        # (the last row of the cumulative product is dropped).
        cumulative_pooling = np.cumprod(np.vstack(self.net_num_pool_op_kernel_sizes), axis=0)
        self.deep_supervision_scales = [[1, 1, 1]] + [list(i) for i in 1 / cumulative_pooling][:-1]

        if self.threeD:
            # Work on a private copy: the defaults are module-level dicts shared by every trainer in
            # the process, and the overrides below must not leak into them.
            self.data_aug_params = deepcopy(default_3D_augmentation_params)

            max_rot = 30. / 360 * 2. * np.pi
            self.data_aug_params['rotation_x'] = (-max_rot, max_rot)
            self.data_aug_params['rotation_y'] = (-max_rot, max_rot)
            self.data_aug_params['rotation_z'] = (-max_rot, max_rot)

            if self.do_dummy_2D_aug:
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            # Private copy, see above. The widened rotation range is applied to the copy instead of
            # being written into the shared 2D defaults.
            self.data_aug_params = deepcopy(default_2D_augmentation_params)
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                full_rot = 180. / 360 * 2. * np.pi
                self.data_aug_params['rotation_x'] = (-full_rot, full_rot)

        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        # NOTE: get_patch_size is evaluated before 'scale_range' is overridden further down, so it
        # sees whatever scale range the dict holds at this point.
        if self.do_dummy_2D_aug:
            in_plane_patch_size = get_patch_size(self.patch_size[1:],
                                                 self.data_aug_params['rotation_x'],
                                                 self.data_aug_params['rotation_y'],
                                                 self.data_aug_params['rotation_z'],
                                                 self.data_aug_params['scale_range'])
            # The first axis keeps its original extent; only the remaining axes are enlarged.
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(in_plane_patch_size))
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size,
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])

        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size

        # Trainer-specific overrides of the augmentation settings.
        self.data_aug_params["p_rot"] = 0.3

        self.data_aug_params["scale_range"] = (0.65, 1.6)
        self.data_aug_params["p_scale"] = 0.3
        self.data_aug_params["independent_scale_factor_for_each_axis"] = True
        self.data_aug_params["p_independent_scale_per_axis"] = 0.3

        self.data_aug_params["do_elastic"] = True
        self.data_aug_params["p_eldef"] = 0.3
        self.data_aug_params["eldef_deformation_scale"] = (0, 0.25)

        self.data_aug_params["do_additive_brightness"] = True
        self.data_aug_params["additive_brightness_mu"] = 0
        self.data_aug_params["additive_brightness_sigma"] = 0.2
        self.data_aug_params["additive_brightness_p_per_sample"] = 0.3
        self.data_aug_params["additive_brightness_p_per_channel"] = 1

        self.data_aug_params['gamma_range'] = (0.5, 1.6)

        self.data_aug_params['num_cached_per_thread'] = 4

    def initialize(self, training=True, force_load_plans=False):
        """Set up plans, augmentation, deep-supervision loss, data generators, network and optimizer.

        Does nothing but log a message when ``self.was_initialized`` is already true.

        :param training: when true, the basic generators are created, the dataset is optionally
            unpacked and the augmentation generators are built.
        :param force_load_plans: when true, the plans file is (re)loaded even if ``self.plans`` is set.
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # we need to know the number of outputs of the network
            net_numpool = len(self.net_num_pool_op_kernel_sizes)

            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
            # this gives higher resolution outputs more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

            # the mask is False only for the last (lowest resolution) entry, whose weight is set to 0.
            # Normalize the remaining weights so that they sum to 1
            mask = np.array([True] + [i < net_numpool - 1 for i in range(1, net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.ds_loss_weights = weights
            # now wrap the loss
            self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
            ################# END ###################

            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_insaneDA_augmentation2(
                    self.dl_tr, self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory
                )
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    # Disabled debugging helper, kept for reference (was previously parked in a bare string literal):
    #
    # def run_training(self):
    #     from batchviewer import view_batch
    #     a = next(self.tr_gen)
    #     view_batch(a['data'][:, 0], width=512, height=512)
    #     import IPython;IPython.embed()


class nnUNetTrainerV2_DA3_BN(nnUNetTrainerV2_DA3):
    """Same trainer as nnUNetTrainerV2_DA3, but the network is built with BatchNorm layers."""

    def initialize_network(self):
        """Create ``self.network`` as a Generic_UNet using BatchNorm, Dropout (p=0) and LeakyReLU.

        3D or 2D operators are selected via ``self.threeD``. The network is moved to the GPU when
        CUDA is available, and ``softmax_helper`` is installed as its inference nonlinearity.
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.BatchNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.BatchNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}

        # Arguments are passed positionally and must stay in the order expected by Generic_UNet
        # (its signature is defined outside this file).
        self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
                                    len(self.net_num_pool_op_kernel_sizes),
                                    self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                    dropout_op_kwargs,
                                    net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
                                    self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper