# Scraped source-page metadata: uploaded by ho11laqe, commit "init" (ecf08bc), file size 5.01 kB
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \
SegChannelSelectionTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from nnunet.training.data_augmentation.custom_transforms import ConvertSegmentationToRegionsTransform
from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
from nnunet.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
try:
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError as ie:
NonDetMultiThreadedAugmenter = None
def _build_no_augmentation_transforms(params, deep_supervision_scales=None, soft_ds=False, classes=None,
                                      regions=None):
    """
    Build the augmentation-free transform pipeline shared by the training and validation generators.

    Both generators previously assembled this list by hand, and the two copies had drifted apart. Keeping it
    in one place guarantees train and val batches are prepared identically.

    :param params: augmentation parameter dict; only "selected_data_channels" and "selected_seg_channels" are
        read here
    :param deep_supervision_scales: if not None, 'target' is additionally downsampled for deep supervision
    :param soft_ds: if True, use DownsampleSegForDSTransform3 (requires `classes`), otherwise
        DownsampleSegForDSTransform2
    :param classes: forwarded to DownsampleSegForDSTransform3; must not be None when soft_ds is True
    :param regions: if not None, 'target' is passed through ConvertSegmentationToRegionsTransform
    :return: a Compose wrapping the transforms, ending with conversion of 'data' and 'target' to float tensors
    """
    transforms = []

    if params.get("selected_data_channels") is not None:
        transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    transforms.append(RemoveLabelTransform(-1, 0))
    # from here on the segmentation is addressed under the key 'target'
    transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None, "get_no_augmentation: `classes` must be provided when soft_ds=True"
            transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
                                                           output_key='target'))

    transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    return Compose(transforms)


def get_no_augmentation(dataloader_train, dataloader_val, params=default_3D_augmentation_params,
                        deep_supervision_scales=None, soft_ds=False,
                        classes=None, pin_memory=True, regions=None):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation

    Both generators apply the same non-augmenting pipeline (see _build_no_augmentation_transforms) and are
    restarted before being returned.

    :param dataloader_train: data loader feeding the training generator
    :param dataloader_val: data loader feeding the validation generator
    :param params: augmentation parameter dict; reads "num_threads", "num_cached_per_thread",
        "selected_data_channels" and "selected_seg_channels"
    :param deep_supervision_scales: if not None, targets are downsampled for deep supervision
    :param soft_ds: use the soft (class-aware) deep-supervision downsampling; requires `classes`
    :param classes: required when soft_ds is True
    :param pin_memory: forwarded to both MultiThreadedAugmenter instances
    :param regions: if not None, segmentations are converted to regions
    :return: tuple (batchgenerator_train, batchgenerator_val)
    """
    num_threads = params.get('num_threads')
    num_cached_per_thread = params.get("num_cached_per_thread")

    tr_transforms = _build_no_augmentation_transforms(params, deep_supervision_scales, soft_ds, classes, regions)
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, num_threads,
                                                  num_cached_per_thread,
                                                  seeds=range(num_threads), pin_memory=pin_memory)
    batchgenerator_train.restart()

    # validation uses half as many worker processes as training, but always at least one
    num_val_threads = max(num_threads // 2, 1)
    val_transforms = _build_no_augmentation_transforms(params, deep_supervision_scales, soft_ds, classes, regions)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, num_val_threads,
                                                num_cached_per_thread,
                                                seeds=range(num_val_threads),
                                                pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val