|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from copy import deepcopy |
|
|
|
import numpy as np |
|
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter |
|
from batchgenerators.transforms.abstract_transforms import Compose |
|
from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \ |
|
SegChannelSelectionTransform |
|
from batchgenerators.transforms.color_transforms import GammaTransform |
|
from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform |
|
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor |
|
|
|
from nnunet.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \ |
|
MaskTransform, ConvertSegmentationToRegionsTransform |
|
from nnunet.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \ |
|
ApplyRandomBinaryOperatorTransform, \ |
|
RemoveRandomConnectedComponentFromOneHotEncodingTransform |
|
|
|
try: |
|
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter |
|
except ImportError as ie: |
|
NonDetMultiThreadedAugmenter = None |
|
|
|
|
|
# Default data-augmentation settings used for 3D models. Rotation ranges are given in radians
# (degrees / 360 * 2 * pi). "num_threads" defaults to 12 and can be overridden through the
# nnUNet_n_proc_DA environment variable. Not every key is read in this file (for example the
# additive-brightness settings are not used by get_default_augmentation).
default_3D_augmentation_params = dict(
    selected_data_channels=None,
    selected_seg_channels=[0, 1, 2],

    # elastic deformation
    do_elastic=True,
    elastic_deform_alpha=(0., 900.),
    elastic_deform_sigma=(9., 13.),
    p_eldef=0.2,

    # scaling
    do_scaling=True,
    scale_range=(0.85, 1.25),
    independent_scale_factor_for_each_axis=False,
    p_independent_scale_per_axis=1,
    p_scale=0.2,

    # rotation (radians)
    do_rotation=True,
    rotation_x=(-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    rotation_y=(-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    rotation_z=(-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    rotation_p_per_axis=1,
    p_rot=0.2,

    random_crop=False,
    random_crop_dist_to_border=None,

    # gamma augmentation
    do_gamma=True,
    gamma_retain_stats=True,
    gamma_range=(0.7, 1.5),
    p_gamma=0.3,

    # mirroring
    do_mirror=True,
    mirror_axes=(0, 1, 2),

    dummy_2D=False,
    mask_was_used_for_normalization=None,
    border_mode_data="constant",

    # cascade-related settings (the key spelling "chanel" is kept as-is; callers use it)
    all_segmentation_labels=None,
    move_last_seg_chanel_to_data=False,
    cascade_do_cascade_augmentations=False,
    cascade_random_binary_transform_p=0.4,
    cascade_random_binary_transform_p_per_label=1,
    cascade_random_binary_transform_size=(1, 8),
    cascade_remove_conn_comp_p=0.2,
    cascade_remove_conn_comp_max_size_percent_threshold=0.15,
    cascade_remove_conn_comp_fill_with_other_class_p=0.0,

    # additive brightness
    do_additive_brightness=False,
    additive_brightness_p_per_sample=0.15,
    additive_brightness_p_per_channel=0.5,
    additive_brightness_mu=0.0,
    additive_brightness_sigma=0.1,

    # augmenter settings
    num_threads=int(os.environ.get('nnUNet_n_proc_DA', 12)),
    num_cached_per_thread=1,
)
|
|
|
# 2D defaults: an independent (deep) copy of the 3D defaults with the 2D-specific entries
# overridden. Elastic alpha is reduced to (0, 200), rotation_x spans +-180 degrees, rotation_y and
# rotation_z are zero, and mirroring uses axes (0, 1) only.
default_2D_augmentation_params = deepcopy(default_3D_augmentation_params)
default_2D_augmentation_params.update({
    "elastic_deform_alpha": (0., 200.),
    "elastic_deform_sigma": (9., 13.),
    "rotation_x": (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
    "rotation_y": (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi),
    "rotation_z": (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi),
    "dummy_2D": False,
    "mirror_axes": (0, 1),
})
|
|
|
|
|
def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
    """Return an enlarged patch size that accounts for rotation and down-scaling augmentation.

    The patch-size vector is rotated separately about each axis by the largest allowed angle.
    The element-wise maximum of the absolute rotated coordinates and the original size is kept.
    The result is then divided by the smallest scale factor and truncated to int.

    Args:
        final_patch_size: sequence of spatial patch sizes. Rotation is taken into account for
            2 or 3 entries; any other length is only divided by the smallest scale factor.
        rot_x, rot_y, rot_z: rotation ranges in radians, given either as a (min, max)
            tuple/list/array or as a single scalar. Only the largest magnitude is used, capped at
            90 degrees (pi / 2). rot_y and rot_z are only used for 3D patch sizes.
        scale_range: sequence of scale factors; its minimum enlarges the patch.

    Returns:
        numpy int array with the enlarged patch size (fractional parts truncated).
    """
    from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d

    max_angle = 90 / 360 * 2. * np.pi

    def _clamped_magnitude(rot):
        # Ranges and scalars are handled alike: take the largest absolute angle and cap it at
        # 90 degrees. Previously a negative scalar bypassed both abs() and the cap.
        return min(max_angle, float(np.max(np.abs(rot))))

    rot_x = _clamped_magnitude(rot_x)
    rot_y = _clamped_magnitude(rot_y)
    rot_z = _clamped_magnitude(rot_z)

    coords = np.array(final_patch_size)
    # Float copy so the in-place division below also works when neither rotation branch runs
    # (an int array cannot be divided in place by a float).
    final_shape = coords.astype(float)
    if len(coords) == 3:
        # Rotate about each axis independently and keep the largest extent per dimension.
        for angles in ((rot_x, 0, 0), (0, rot_y, 0), (0, 0, rot_z)):
            final_shape = np.maximum(np.abs(rotate_coords_3d(coords, *angles)), final_shape)
    elif len(coords) == 2:
        final_shape = np.maximum(np.abs(rotate_coords_2d(coords, rot_x)), final_shape)
    final_shape /= min(scale_range)
    return final_shape.astype(int)
|
|
|
|
|
def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Build the training and validation batch generators with the default augmentation pipeline.

    Training pipeline, in order (each step only if enabled in ``params``):

    - data/seg channel selection;
    - optional 3D->2D conversion around SpatialTransform (elastic deformation, rotation, scaling);
    - gamma augmentation, mirroring and masking;
    - mapping label -1 to 0;
    - optional cascade augmentations on the segmentation moved into ``data``;
    - renaming ``seg`` to ``target``, optional region conversion and conversion to tensors.

    The validation pipeline only contains the non-random steps.

    Args:
        dataloader_train: data loader wrapped by the training MultiThreadedAugmenter.
        dataloader_val: data loader wrapped by the validation MultiThreadedAugmenter.
        patch_size: final patch size. When ``params["dummy_2D"]`` is truthy, the first entry is
            dropped before passing it to SpatialTransform.
        params: augmentation parameter dict (see ``default_3D_augmentation_params``). It is only
            read, never modified.
        border_val_seg: constant fill value for the segmentation in SpatialTransform.
        pin_memory: forwarded to both MultiThreadedAugmenter instances.
        seeds_train: seeds forwarded to the training augmenter.
        seeds_val: seeds forwarded to the validation augmenter.
        regions: if not None, forwarded to ConvertSegmentationToRegionsTransform in both pipelines.

    Returns:
        tuple (batchgenerator_train, batchgenerator_val).

    Raises:
        AssertionError: if ``params`` still uses the deprecated ``mirror`` key.
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    # Only the truthiness of these switches matters; a missing key counts as disabled.
    do_dummy_2d = params.get("dummy_2D")
    move_seg_to_data = params.get("move_last_seg_chanel_to_data")

    # ------------------------------------------------------------------ training
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # In dummy-2D mode the spatial transform runs on 2D slices: convert, drop the first patch
    # axis for SpatialTransform, and convert back afterwards.
    if do_dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size

    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if do_dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    # replace label -1 by 0
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if move_seg_to_data:
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # (The original condition was `x and not None and x`, i.e. just the truthiness of x.)
        if params.get("cascade_do_cascade_augmentations"):
            # The cascade transforms act on the last len(all_segmentation_labels) channels of
            # 'data' -- presumably the one-hot channels added by MoveSegAsOneHotToData above.
            num_labels = len(params.get("all_segmentation_labels"))
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-num_labels, 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # Each value goes to the keyword matching its params key: the fill probability to
            # fill_with_other_class_p, the size threshold to dont_do_if_covers_more_than_X_percent.
            # These two were previously swapped.
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-num_labels, 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_fill_with_other_class_p"),
                dont_do_if_covers_more_than_X_percent=params.get(
                    "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)

    # ---------------------------------------------------------------- validation
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if move_seg_to_data:
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # the validation augmenter uses half the training worker count (at least one)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the default augmenters for stage 0 of a preprocessed task.
    from nnunet.training.dataloading.dataset_loading import DataLoader3D, load_dataset
    from nnunet.paths import preprocessing_output_dir
    import pickle

    task_name = "Task002_Heart"
    task_dir = os.path.join(preprocessing_output_dir, task_name)
    dataset = load_dataset(task_dir, 0)
    with open(os.path.join(task_dir, "plans.pkl"), 'rb') as plans_file:
        # NOTE: pickle executes arbitrary code on load; only use with trusted, locally created plans.
        plans = pickle.load(plans_file)

    stage0_patch_size = np.array(plans['stage_properties'][0].patch_size)
    aug = default_3D_augmentation_params
    basic_patch_size = get_patch_size(stage0_patch_size,
                                      aug['rotation_x'],
                                      aug['rotation_y'],
                                      aug['rotation_z'],
                                      aug['scale_range'])

    dl = DataLoader3D(dataset, basic_patch_size, stage0_patch_size.astype(int), 1)
    tr, val = get_default_augmentation(dl, dl, stage0_patch_size.astype(int))
|
|