# nnUNet_calvingfront_detection/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_noDeepSupervision.py
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import nn

from batchgenerators.utilities.file_and_folder_operations import *

from nnunet.network_architecture.generic_UNet import Generic_UNet
from nnunet.network_architecture.initialization import InitWeights_He
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
    default_2D_augmentation_params, get_patch_size
from nnunet.training.dataloading.dataset_loading import unpack_dataset
from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
from nnunet.utilities.nd_softmax import softmax_helper


class nnUNetTrainerV2_noDeepSupervision(nnUNetTrainerV2):
    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                         deterministic, fp16)
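        # With deep supervision disabled there is only a single full-resolution output to
        # supervise, so the plain Dice + cross-entropy loss is used directly here (no
        # multi-scale loss wrapper as in nnUNetTrainerV2).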
        self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

    def setup_DA_params(self):
        """
        We leave out the creation of self.deep_supervision_scales, so it remains None.
        :return:
        """
        if self.threeD:
            self.data_aug_params = default_3D_augmentation_params
            self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            if self.do_dummy_2D_aug:
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params

        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        if self.do_dummy_2D_aug:
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])

        self.data_aug_params["scale_range"] = (0.7, 1.4)
        self.data_aug_params["do_elastic"] = False
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size

    def initialize(self, training=True, force_load_plans=False):
        """
        Removed deep supervision.
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")
                assert self.deep_supervision_scales is None
                self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
                                                                    self.data_aug_params[
                                                                        'patch_size_for_spatialtransform'],
                                                                    self.data_aug_params,
                                                                    deep_supervision_scales=self.deep_supervision_scales,
                                                                    classes=None,
                                                                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def initialize_network(self):
        """
        Changed deep supervision to False.
        :return:
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
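        # The first positional False below is Generic_UNet's deep_supervision flag (the change this
        # trainer exists for); the final nonlinearity is the identity (lambda x: x) because softmax
        # is attached separately for inference further down. The remaining arguments follow the
        # standard nnUNetTrainerV2 construction.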
        self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
                                    len(self.net_num_pool_op_kernel_sizes),
                                    self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                    dropout_op_kwargs,
                                    net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
                                    self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

    def run_online_evaluation(self, output, target):
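        # Without deep supervision the network returns a single tensor rather than a list of
        # per-scale outputs, so the base nnUNetTrainer implementation is used instead of
        # nnUNetTrainerV2's, which expects output[0] / target[0].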
        return nnUNetTrainer.run_online_evaluation(self, output, target)
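

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). Trainers like this one
# are normally launched through nnU-Net's CLI, e.g.
#     nnUNet_train 2d nnUNetTrainerV2_noDeepSupervision TaskXXX_MYTASK 0
# The direct construction below is an illustration only; every path is a
# hypothetical placeholder.
if __name__ == "__main__":
    trainer = nnUNetTrainerV2_noDeepSupervision(
        plans_file="/path/to/nnUNetPlans_plans_2D.pkl",    # hypothetical path
        fold=0,
        output_folder="/path/to/output_folder",            # hypothetical path
        dataset_directory="/path/to/preprocessed_task",    # hypothetical path
        batch_dice=True,
        stage=0,
        unpack_data=True,
        deterministic=True,
        fp16=False,
    )
    trainer.initialize(training=True)  # builds generators, network and optimizer
    trainer.run_training()             # standard nnU-Net training loop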