# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p
from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
try:
    from meddec.model_training.ablation_studies.new_nnUNet_candidates.nnUNetTrainerCandidate23_softDeepSupervision4 import \
        MyDSLoss4
except ImportError:
    MyDSLoss4 = None
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.training.dataloading.dataset_loading import unpack_dataset
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
from torch import nn
import numpy as np


class nnUNetTrainerV2_softDeepSupervision(nnUNetTrainerV2):
    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                         deterministic, fp16)
        self.loss = None  # set up in initialize(), where the loss is wrapped for deep supervision

    def initialize(self, training=True, force_load_plans=False):
        """
        - replaced get_default_augmentation with get_moreDA_augmentation
        - this code is only run once (guarded by self.was_initialized)
        - the loss function is wrapped for deep supervision
        :param training:
        :param force_load_plans:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # we need to know the number of outputs of the network
            net_numpool = len(self.net_num_pool_op_kernel_sizes)

            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases.
            # This gives higher resolution outputs more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

            # we don't use the lowest resolution output (its weight is set to 0). Normalize the remaining weights so
            # that they sum to 1
            mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
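            # Worked example of the weighting scheme above (illustrative only; net_numpool = 5 is an assumed value,
            # the real number comes from the plans file):
            #   raw weights        : [1, 0.5, 0.25, 0.125, 0.0625]
            #   mask               : [True, True, True, True, False]  -> lowest resolution output is dropped
            #   after masking      : [1, 0.5, 0.25, 0.125, 0]
            #   after normalization: [0.5333, 0.2667, 0.1333, 0.0667, 0]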
            # now wrap the loss
            if MyDSLoss4 is None:
                raise RuntimeError("MyDSLoss4 could not be imported, so this trainer cannot be used yet")
            self.loss = MyDSLoss4(self.batch_dice, weights)
            # self.loss = MultipleOutputLoss2(self.loss, weights)
            ################# END ###################

            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow as a result. Pray you are not using 2d, or "
                        "you will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
                                                                    self.data_aug_params[
                                                                        'patch_size_for_spatialtransform'],
                                                                    self.data_aug_params,
                                                                    deep_supervision_scales=self.deep_supervision_scales,
                                                                    soft_ds=True,
                                                                    classes=[0] + list(self.classes),
                                                                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def run_online_evaluation(self, output, target):
        """
        With deep supervision, both the network output and the reference are lists of tensors, ordered from full
        resolution downwards. Only the full resolution output is of interest for evaluation, so all other entries
        are ignored.
        :param output:
        :param target:
        :return:
        """
        target = target[0][:, None]  # restore the channel dimension to stay compatible with the previous code
        output = output[0]
        return nnUNetTrainer.run_online_evaluation(self, output, target)
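

# Minimal usage sketch (illustrative only, not part of the original trainer): it shows how this trainer would
# typically be instantiated and run. All paths, the fold, and the stage value below are hypothetical placeholders,
# and MyDSLoss4 (from the meddec package) must be importable for initialize() to succeed.
if __name__ == "__main__":
    trainer = nnUNetTrainerV2_softDeepSupervision(
        plans_file="/path/to/preprocessed/TaskXXX/nnUNetPlans_plans_3D.pkl",  # hypothetical path
        fold=0,
        output_folder="/path/to/output_folder",  # hypothetical path
        dataset_directory="/path/to/preprocessed/TaskXXX",  # hypothetical path
        batch_dice=True,
        stage=1,  # assumed stage index; must match the plans file
        unpack_data=True,
        deterministic=True,
        fp16=False)
    trainer.initialize(training=True)  # wraps the loss for soft deep supervision as described above
    trainer.run_training()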