File size: 12,509 Bytes
ecf08bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from copy import deepcopy
from typing import Union, Tuple
import numpy as np
import SimpleITK as sitk
from batchgenerators.augmentations.utils import resize_segmentation
from nnunet.preprocessing.preprocessing import get_lowres_axis, get_do_separate_z, resample_data_or_seg
from batchgenerators.utilities.file_and_folder_operations import *
def _write_itk_image_with_geometry(array: np.ndarray, properties_dict: dict, fname: str) -> None:
    """Write ``array`` to ``fname`` with SimpleITK, copying spacing, origin and direction from ``properties_dict``."""
    image = sitk.GetImageFromArray(array)
    image.SetSpacing(properties_dict['itk_spacing'])
    image.SetOrigin(properties_dict['itk_origin'])
    image.SetDirection(properties_dict['itk_direction'])
    sitk.WriteImage(image, fname)


def save_segmentation_nifti_from_softmax(segmentation_softmax: Union[str, np.ndarray], out_fname: str,
                                         properties_dict: dict, order: int = 1,
                                         region_class_order: Tuple[Tuple[int]] = None,
                                         seg_postprogess_fn: callable = None, seg_postprocess_args: tuple = None,
                                         resampled_npz_fname: str = None,
                                         non_postprocessed_fname: str = None, force_separate_z: bool = None,
                                         interpolation_order_z: int = 0, verbose: bool = True):
    """
    Resample a softmax prediction back to the geometry of the original image and export it.

    The prediction is resampled to ``properties_dict['size_after_cropping']`` (if its spatial shape differs),
    converted to a label map, pasted into an array of the pre-cropping size using
    ``properties_dict['crop_bbox']`` and written to ``out_fname`` with the ITK spacing/origin/direction stored
    in ``properties_dict``. Optionally the resampled softmax is stored as npz (plus a pkl with the properties).

    Objects larger than 2 GB cannot be communicated between processes, which is why ``segmentation_softmax``
    may also be the name of a npy file. In that case the file is loaded and then deleted.

    :param segmentation_softmax: softmax array (indexed as channel first) or path to a npy file holding it.
        A file passed here is removed after loading
    :param out_fname: file name of the exported segmentation
    :param properties_dict: properties of the case. Keys read here: 'size_after_cropping',
        'original_size_of_raw_data', 'original_spacing', 'spacing_after_resampling', 'classes', 'crop_bbox',
        'itk_spacing', 'itk_origin', 'itk_direction'. If resampled_npz_fname and region_class_order are both
        given, the key 'regions_class_order' is added to this dict before it is pickled
    :param order: interpolation order passed to resample_data_or_seg
    :param region_class_order: if None, the label map is built from the label groups in
        properties_dict['classes']. Otherwise channel i is thresholded at 0.5 and the voxels above the
        threshold are set to region_class_order[i] (later entries overwrite earlier ones)
    :param seg_postprogess_fn: if not None, seg_postprogess_fn(segmentation, *seg_postprocess_args) is called
        on a copy of the segmentation before export
    :param seg_postprocess_args: extra positional arguments for seg_postprogess_fn. None means no extra
        arguments
    :param resampled_npz_fname: if not None, the resampled softmax is saved here (float16, key 'softmax') and
        properties_dict is pickled next to it (last four characters of the name replaced by '.pkl')
    :param non_postprocessed_fname: if not None and seg_postprogess_fn is not None, the segmentation before
        postprocessing is written to this file as well
    :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then
        always/never resample along z separately. Do not touch unless you know what you are doing
    :param interpolation_order_z: if separate z resampling is done then this is the order for resampling in z
    :param verbose: print the resampling decisions
    :return: None
    """
    if verbose:
        print("force_separate_z:", force_separate_z, "interpolation order:", order)

    if isinstance(segmentation_softmax, str):
        assert isfile(segmentation_softmax), "If isinstance(segmentation_softmax, str) then " \
                                             "isfile(segmentation_softmax) must be True"
        softmax_file = segmentation_softmax
        segmentation_softmax = np.load(softmax_file)
        # the file was only a vehicle for inter-process communication
        os.remove(softmax_file)

    # first resample, then put result into bbox of cropping, then save
    current_shape = segmentation_softmax.shape
    shape_original_after_cropping = properties_dict.get('size_after_cropping')
    shape_original_before_cropping = properties_dict.get('original_size_of_raw_data')

    if np.any([i != j for i, j in zip(np.array(current_shape[1:]), np.array(shape_original_after_cropping))]):
        if force_separate_z is None:
            if get_do_separate_z(properties_dict.get('original_spacing')):
                do_separate_z = True
                lowres_axis = get_lowres_axis(properties_dict.get('original_spacing'))
            elif get_do_separate_z(properties_dict.get('spacing_after_resampling')):
                do_separate_z = True
                lowres_axis = get_lowres_axis(properties_dict.get('spacing_after_resampling'))
            else:
                do_separate_z = False
                lowres_axis = None
        else:
            do_separate_z = force_separate_z
            if do_separate_z:
                lowres_axis = get_lowres_axis(properties_dict.get('original_spacing'))
            else:
                lowres_axis = None

        if lowres_axis is not None and len(lowres_axis) != 1:
            # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to
            # resample separately in the out of plane axis
            do_separate_z = False

        if verbose:
            print("separate z:", do_separate_z, "lowres axis", lowres_axis)
        seg_old_spacing = resample_data_or_seg(segmentation_softmax, shape_original_after_cropping, is_seg=False,
                                               axis=lowres_axis, order=order, do_separate_z=do_separate_z,
                                               order_z=interpolation_order_z)
    else:
        if verbose:
            print("no resampling necessary")
        seg_old_spacing = segmentation_softmax

    if resampled_npz_fname is not None:
        np.savez_compressed(resampled_npz_fname, softmax=seg_old_spacing.astype(np.float16))
        # this is needed for ensembling if the nonlinearity is sigmoid
        if region_class_order is not None:
            properties_dict['regions_class_order'] = region_class_order
        save_pickle(properties_dict, resampled_npz_fname[:-4] + ".pkl")

    if region_class_order is None:
        # properties_dict['classes'] is iterated as a sequence of label groups; each group consumes len(group)
        # consecutive softmax channels and produces one output channel
        labels = properties_dict['classes']
        # NOTE(review): the output buffer drops the channel axis and then overwrites the following axis with
        # the number of label groups. This presumably assumes a singleton first spatial axis (2D data stored
        # as (c, 1, y, x)) - TODO confirm
        shape = np.array(seg_old_spacing.shape)[1:]
        shape[0] = len(labels)
        seg_old_space = np.zeros(shape)
        counter = 0
        for i, label in enumerate(labels):
            classes = len(label)
            if classes == 1:
                # a group with a single channel keeps the (resampled) softmax values as they are
                seg_old_space[i] = seg_old_spacing[counter:classes + counter]
            else:
                # several channels: index of the strongest channel within this group
                seg_old_space[i] = seg_old_spacing[counter:classes + counter].argmax(0)[0]
            counter += classes
        seg_old_spacing = seg_old_space
    else:
        seg_old_spacing_final = np.zeros(seg_old_spacing.shape[1:])
        for i, c in enumerate(region_class_order):
            seg_old_spacing_final[seg_old_spacing[i] > 0.5] = c
        seg_old_spacing = seg_old_spacing_final

    bbox = properties_dict.get('crop_bbox')
    if bbox is not None:
        # work on a copy so that the caller's 'original_size_of_raw_data' is not overwritten
        full_shape = list(shape_original_before_cropping)
        full_shape[0] = len(properties_dict['classes'])
        seg_old_size = np.zeros(full_shape)
        # clamp the upper bounds of the crop to the full image; computed locally so that the caller's
        # 'crop_bbox' stays untouched. Only axes 1 and 2 are cropped, axis 0 is filled completely
        upper_1 = min(bbox[1][0] + seg_old_spacing.shape[1], full_shape[1])
        upper_2 = min(bbox[2][0] + seg_old_spacing.shape[2], full_shape[2])
        seg_old_size[:, bbox[1][0]:upper_1, bbox[2][0]:upper_2] = seg_old_spacing
    else:
        seg_old_size = seg_old_spacing

    if seg_postprogess_fn is not None:
        # seg_postprocess_args defaults to None, which cannot be unpacked
        postprocess_args = () if seg_postprocess_args is None else seg_postprocess_args
        seg_old_size_postprocessed = seg_postprogess_fn(np.copy(seg_old_size), *postprocess_args)
    else:
        seg_old_size_postprocessed = seg_old_size

    _write_itk_image_with_geometry(seg_old_size_postprocessed.astype(float), properties_dict, out_fname)

    if (non_postprocessed_fname is not None) and (seg_postprogess_fn is not None):
        # NOTE(review): this export is cast to uint8 while the main export is written as float - confirm
        # that this is intended
        _write_itk_image_with_geometry(seg_old_size.astype(np.uint8), properties_dict, non_postprocessed_fname)
def save_segmentation_nifti(segmentation, out_fname, dct, order=1, force_separate_z=None, order_z=0):
    """
    Resample a segmentation (label map) back to the geometry of the original image and write it to out_fname.

    faster and uses less ram than save_segmentation_nifti_from_softmax, but maybe less precise and also does not
    support softmax export (which is needed for ensembling). So it's a niche function that may be useful in some
    cases.

    Everything printed to sys.stdout after the first status line is suppressed while this function runs. The
    previously active sys.stdout is restored afterwards, also if an exception is raised.

    :param segmentation: segmentation array or path to a npy file holding it. A file passed here is removed
        after loading
    :param out_fname: file name of the exported segmentation (written as uint8)
    :param dct: properties of the case. Keys read here: 'size_after_cropping', 'original_size_of_raw_data',
        'original_spacing', 'spacing_after_resampling', 'crop_bbox', 'itk_spacing', 'itk_origin',
        'itk_direction'
    :param order: interpolation order. 0 uses resize_segmentation, everything else resample_data_or_seg
    :param force_separate_z: if None then it is decided from the spacings whether z is resampled separately,
        if True/False then z is always/never resampled separately (only relevant for order != 0)
    :param order_z: if separate z resampling is done then this is the order for resampling in z
    :return: None
    """
    from contextlib import redirect_stdout

    print("force_separate_z:", force_separate_z, "interpolation order:", order)

    # suppress output. The context managers close the devnull handle and restore the previous stdout even if
    # something below raises
    with open(os.devnull, 'w') as devnull, redirect_stdout(devnull):
        if isinstance(segmentation, str):
            assert isfile(segmentation), "If isinstance(segmentation_softmax, str) then " \
                                         "isfile(segmentation_softmax) must be True"
            segmentation_file = segmentation
            segmentation = np.load(segmentation_file)
            os.remove(segmentation_file)

        # first resample, then put result into bbox of cropping, then save
        current_shape = segmentation.shape
        shape_original_after_cropping = dct.get('size_after_cropping')
        shape_original_before_cropping = dct.get('original_size_of_raw_data')

        if np.any(np.array(current_shape) != np.array(shape_original_after_cropping)):
            if order == 0:
                seg_old_spacing = resize_segmentation(segmentation, shape_original_after_cropping, 0)
            else:
                if force_separate_z is None:
                    if get_do_separate_z(dct.get('original_spacing')):
                        do_separate_z = True
                        lowres_axis = get_lowres_axis(dct.get('original_spacing'))
                    elif get_do_separate_z(dct.get('spacing_after_resampling')):
                        do_separate_z = True
                        lowres_axis = get_lowres_axis(dct.get('spacing_after_resampling'))
                    else:
                        do_separate_z = False
                        lowres_axis = None
                else:
                    do_separate_z = force_separate_z
                    if do_separate_z:
                        lowres_axis = get_lowres_axis(dct.get('original_spacing'))
                    else:
                        lowres_axis = None

                if lowres_axis is not None and len(lowres_axis) != 1:
                    # same guard as in save_segmentation_nifti_from_softmax: this happens for spacings like
                    # (0.24, 1.25, 1.25) for example. In that case we do not want to resample separately in
                    # the out of plane axis
                    do_separate_z = False

                print("separate z:", do_separate_z, "lowres axis", lowres_axis)
                # resample_data_or_seg is given a leading axis of length 1, which is dropped again afterwards
                seg_old_spacing = resample_data_or_seg(segmentation[None], shape_original_after_cropping,
                                                       is_seg=True, axis=lowres_axis, order=order,
                                                       do_separate_z=do_separate_z, order_z=order_z)[0]
        else:
            seg_old_spacing = segmentation

        bbox = dct.get('crop_bbox')
        if bbox is not None:
            seg_old_size = np.zeros(shape_original_before_cropping)
            # clamp the upper bounds of the crop to the full image; computed locally so that the caller's
            # 'crop_bbox' is not modified
            upper = [min(bbox[c][0] + seg_old_spacing.shape[c], shape_original_before_cropping[c])
                     for c in range(3)]
            seg_old_size[bbox[0][0]:upper[0],
                         bbox[1][0]:upper[1],
                         bbox[2][0]:upper[2]] = seg_old_spacing
        else:
            seg_old_size = seg_old_spacing

        seg_resized_itk = sitk.GetImageFromArray(seg_old_size.astype(np.uint8))
        seg_resized_itk.SetSpacing(dct['itk_spacing'])
        seg_resized_itk.SetOrigin(dct['itk_origin'])
        seg_resized_itk.SetDirection(dct['itk_direction'])
        sitk.WriteImage(seg_resized_itk, out_fname)
|