File size: 9,841 Bytes
ecf08bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunet.network_architecture.generic_UNet import Generic_UNet
from nnunet.network_architecture.initialization import InitWeights_He
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.training.data_augmentation.data_augmentation_insaneDA2 import get_insaneDA_augmentation2
from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
default_2D_augmentation_params, get_patch_size
from nnunet.training.dataloading.dataset_loading import unpack_dataset
from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2, maybe_mkdir_p
from nnunet.utilities.nd_softmax import softmax_helper
from torch import nn
class nnUNetTrainerV2_DA3(nnUNetTrainerV2):
    """nnUNetTrainerV2 variant with a more aggressive data augmentation pipeline ("DA3").

    The augmentation parameters are overridden with:
    - a wider scaling range (0.65, 1.6) with optional independent per-axis scaling;
    - elastic deformation;
    - additive brightness;
    - a wider gamma range.

    Training/validation generators are built with ``get_insaneDA_augmentation2``.
    """

    def setup_DA_params(self):
        """Configure the DA3 augmentation state.

        Sets ``self.deep_supervision_scales``, ``self.data_aug_params`` and
        ``self.basic_generator_patch_size``.

        The module-level ``default_3D_augmentation_params`` / ``default_2D_augmentation_params``
        dicts are deep-copied before being modified. Assigning them directly made every write
        below mutate the shared defaults. DA3-specific settings (e.g. ``gamma_range``, ``p_rot``,
        ``do_additive_brightness``) then leaked into any code that reads those dicts later in the
        same process.

        The copy is taken at the same point the original aliased the dict. The values produced
        by a single call are therefore unchanged.
        """
        from copy import deepcopy  # local import keeps this fix self-contained

        super().setup_DA_params()
        # Downsampling factors per deep-supervision output:
        # - full resolution ([1, 1, 1]) first;
        # - then the inverse cumulative product of the pooling kernels;
        # - the scale after the last pooling step is dropped.
        self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
            np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]

        if self.threeD:
            self.data_aug_params = deepcopy(default_3D_augmentation_params)
            rotation_30_deg = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_x'] = rotation_30_deg
            self.data_aug_params['rotation_y'] = rotation_30_deg
            self.data_aug_params['rotation_z'] = rotation_30_deg
            if self.do_dummy_2D_aug:
                # 3D data augmented slice-wise: take the in-plane settings from the 2D defaults
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            self.data_aug_params = deepcopy(default_2D_augmentation_params)
            # elongated 2D patches (aspect ratio > 1.5) get the full +/-180 degree in-plane rotation
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                self.data_aug_params['rotation_x'] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        # NOTE(review): this uses the 'scale_range' present *before* the DA3 override below
        # (ordering kept from the original). TODO confirm whether the wider (0.65, 1.6) range
        # should be taken into account here.
        if self.do_dummy_2D_aug:
            # Only the in-plane axes (patch_size[1:]) are rotated/scaled.
            # The first axis is prepended unchanged.
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])

        # DA3-specific augmentation settings
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size
        self.data_aug_params["p_rot"] = 0.3
        self.data_aug_params["scale_range"] = (0.65, 1.6)
        self.data_aug_params["p_scale"] = 0.3
        self.data_aug_params["independent_scale_factor_for_each_axis"] = True
        self.data_aug_params["p_independent_scale_per_axis"] = 0.3
        self.data_aug_params["do_elastic"] = True
        self.data_aug_params["p_eldef"] = 0.3
        self.data_aug_params["eldef_deformation_scale"] = (0, 0.25)
        self.data_aug_params["do_additive_brightness"] = True
        self.data_aug_params["additive_brightness_mu"] = 0
        self.data_aug_params["additive_brightness_sigma"] = 0.2
        self.data_aug_params["additive_brightness_p_per_sample"] = 0.3
        self.data_aug_params["additive_brightness_p_per_channel"] = 1
        self.data_aug_params['gamma_range'] = (0.5, 1.6)
        self.data_aug_params['num_cached_per_thread'] = 4

    def initialize(self, training=True, force_load_plans=False):
        """Set up plans, augmentation, the deep-supervision loss, data generators, network and optimizer.

        Calling this again after a successful initialization only logs a message.

        :param training: if True, build the training/validation loaders and augmentation
            generators (unpacking the dataset first when ``self.unpack_data`` is set).
        :param force_load_plans: reload the plans file even if ``self.plans`` is already set.
        """
        if self.was_initialized:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
            return

        maybe_mkdir_p(self.output_folder)

        if force_load_plans or (self.plans is None):
            self.load_plans_file()

        self.process_plans(self.plans)
        self.setup_DA_params()

        ################# Here we wrap the loss for deep supervision ############
        # one weight per deep-supervision output, one output per pooling step
        net_numpool = len(self.net_num_pool_op_kernel_sizes)
        # Weights halve with each output (1, 1/2, 1/4, ...), so higher-resolution outputs weigh more.
        weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
        # Only the last (lowest-resolution) output is excluded; the first is always kept.
        # The remaining weights are normalized to sum to 1.
        mask = np.array([True] + [i < net_numpool - 1 for i in range(1, net_numpool)])
        weights[~mask] = 0
        weights = weights / weights.sum()
        self.ds_loss_weights = weights
        self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
        ################# END ###################

        self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                  "_stage%d" % self.stage)
        if training:
            self.dl_tr, self.dl_val = self.get_basic_generators()
            if self.unpack_data:
                print("unpacking dataset")
                unpack_dataset(self.folder_with_preprocessed_data)
                print("done")
            else:
                print(
                    "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                    "will wait all winter for your model to finish!")

            self.tr_gen, self.val_gen = get_insaneDA_augmentation2(
                self.dl_tr, self.dl_val,
                self.data_aug_params['patch_size_for_spatialtransform'],
                self.data_aug_params,
                deep_supervision_scales=self.deep_supervision_scales,
                pin_memory=self.pin_memory
            )
            self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                   also_print_to_console=False)
            self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                   also_print_to_console=False)

        self.initialize_network()
        self.initialize_optimizer_and_scheduler()

        assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        self.was_initialized = True
class nnUNetTrainerV2_DA3_BN(nnUNetTrainerV2_DA3):
    """DA3 trainer whose Generic_UNet is built with batch normalization layers."""

    def initialize_network(self):
        """Create ``self.network`` as a Generic_UNet using BatchNorm.

        The 2D or 3D layer types are chosen from ``self.threeD``. The network is moved to the
        GPU when CUDA is available. ``softmax_helper`` is attached as its
        ``inference_apply_nonlin``.
        """
        if not self.threeD:
            conv_op, dropout_op, norm_op = nn.Conv2d, nn.Dropout2d, nn.BatchNorm2d
        else:
            conv_op, dropout_op, norm_op = nn.Conv3d, nn.Dropout3d, nn.BatchNorm3d

        norm_kwargs = dict(eps=1e-5, affine=True)
        # dropout probability 0, i.e. dropout is effectively disabled
        dropout_kwargs = dict(p=0, inplace=True)
        nonlin = nn.LeakyReLU
        nonlin_kwargs = dict(negative_slope=1e-2, inplace=True)

        self.network = Generic_UNet(
            self.num_input_channels,
            self.base_num_features,
            self.num_classes,
            len(self.net_num_pool_op_kernel_sizes),
            self.conv_per_stage,
            2,
            conv_op,
            norm_op,
            norm_kwargs,
            dropout_op,
            dropout_kwargs,
            nonlin,
            nonlin_kwargs,
            True,
            False,
            lambda x: x,  # identity passed in the (presumably) final-nonlinearity slot -- confirm against Generic_UNet
            InitWeights_He(1e-2),
            self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes,
            False,
            True,
            True,
        )

        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
|