# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \ SegChannelSelectionTransform from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor from nnunet.training.data_augmentation.custom_transforms import ConvertSegmentationToRegionsTransform from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params from nnunet.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2 try: from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter except ImportError as ie: NonDetMultiThreadedAugmenter = None def get_no_augmentation(dataloader_train, dataloader_val, params=default_3D_augmentation_params, deep_supervision_scales=None, soft_ds=False, classes=None, pin_memory=True, regions=None): """ use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation """ tr_transforms = [] if params.get("selected_data_channels") is not None: tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) tr_transforms.append(RemoveLabelTransform(-1, 0)) tr_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) if deep_supervision_scales is not None: if soft_ds: assert classes is not None tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)) else: tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', output_key='target')) tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'), params.get("num_cached_per_thread"), seeds=range(params.get('num_threads')), pin_memory=pin_memory) batchgenerator_train.restart() val_transforms = [] val_transforms.append(RemoveLabelTransform(-1, 0)) if params.get("selected_data_channels") is not None: val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) val_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) if deep_supervision_scales is not None: if soft_ds: assert classes is not None val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)) else: val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', output_key='target')) val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) val_transforms = Compose(val_transforms) batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1), params.get("num_cached_per_thread"), seeds=range(max(params.get('num_threads') // 2, 1)), pin_memory=pin_memory) batchgenerator_val.restart() return batchgenerator_train, batchgenerator_val