ho11laqe's picture
history blame
No virus
10.5 kB
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from copy import deepcopy
from nnunet.network_architecture.generic_UNet import Generic_UNet
import SimpleITK as sitk
import shutil
from batchgenerators.utilities.file_and_folder_operations import join
def split_4d_nifti(filename, output_folder):
    """
    Split a 4D image into one 3D file per entry along the fourth dimension.

    A 3D input is simply copied to ``output_folder`` with the suffix
    ``_0000.nii.gz``. A 4D input is written as ``<base>_0000.nii.gz``,
    ``<base>_0001.nii.gz``, ... where each output carries the spacing, origin
    and direction of the first three dimensions of the input.

    NOTE(review): the base name is obtained by dropping the last 7 characters
    of the file name, i.e. the file is assumed to end in ``.nii.gz``.

    :param filename: path to the input image
    :param output_folder: folder the resulting 3D file(s) are written to
    :raises RuntimeError: if the image is neither 3D nor 4D
    """
    import os  # local import: only needed for the portable basename below

    img_itk = sitk.ReadImage(filename)
    dim = img_itk.GetDimension()
    file_base = os.path.basename(filename)
    out_prefix = file_base[:-7]

    if dim == 3:
        # nothing to split; without this early return the 4D-only code below
        # would run on a 3D image and fail on the 4x4 direction reshape
        shutil.copy(filename, join(output_folder, out_prefix + "_0000.nii.gz"))
        return
    if dim != 4:
        raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename))

    img_npy = sitk.GetArrayFromImage(img_itk)
    direction = np.array(img_itk.GetDirection()).reshape(4, 4)

    # drop the fourth dimension from the geometry information
    spacing = tuple(img_itk.GetSpacing()[:-1])
    origin = tuple(img_itk.GetOrigin()[:-1])
    direction = tuple(direction[:-1, :-1].reshape(-1))

    # the first numpy axis holds the fourth image dimension
    for t, img in enumerate(img_npy):
        img_itk_new = sitk.GetImageFromArray(img)
        # GetImageFromArray yields default geometry; restore the real one
        img_itk_new.SetSpacing(spacing)
        img_itk_new.SetOrigin(origin)
        img_itk_new.SetDirection(direction)
        sitk.WriteImage(img_itk_new, join(output_folder, out_prefix + "_%04.0d.nii.gz" % t))
def get_pool_and_conv_props_poolLateV2(patch_size, min_feature_map_size, max_numpool, spacing):
    """
    Compute per-stage pooling and convolution kernel sizes, pooling axes with
    fewer pooling operations only in the last stages.

    The number of pooling operations per axis comes from
    ``get_network_numpool``. At stage ``p`` an axis is pooled (kernel 2) only
    if ``num_pool_per_axis[axis] + p >= max(num_pool_per_axis)``; otherwise its
    pool kernel is 1. An axis counts as "reached" once its current spacing is
    larger than half of the largest input spacing. Reached axes get conv
    kernel 1 and all others 3, unless every axis is reached, in which case all
    conv kernels are 3.

    :param patch_size: patch edge lengths, one per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: upper bound on the number of pooling operations per axis
    :param spacing: voxel spacing, one per axis
    :return: (num_pool_per_axis, net_num_pool_op_kernel_sizes, net_conv_kernel_sizes,
              patch_size padded to be divisible by must_be_divisible_by, must_be_divisible_by)
    """
    dim = len(patch_size)
    reach = max(spacing)

    num_pool_per_axis = get_network_numpool(patch_size, max_numpool, min_feature_map_size)
    net_numpool = max(num_pool_per_axis)

    net_num_pool_op_kernel_sizes = []
    net_conv_kernel_sizes = []

    current_spacing = spacing  # only rebound below, the caller's list is never mutated
    for p in range(net_numpool):
        reached = [current_spacing[i] / reach > 0.5 for i in range(dim)]
        pool = [2 if num_pool_per_axis[i] + p >= net_numpool else 1 for i in range(dim)]
        if all(reached):
            conv = [3] * dim
        else:
            conv = [1 if reached[i] else 3 for i in range(dim)]
        net_num_pool_op_kernel_sizes.append(pool)
        net_conv_kernel_sizes.append(conv)
        current_spacing = [s * k for s, k in zip(current_spacing, pool)]

    # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
    net_conv_kernel_sizes.append([3] * dim)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    return num_pool_per_axis, net_num_pool_op_kernel_sizes, net_conv_kernel_sizes, patch_size, must_be_divisible_by
def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
    """
    Iteratively derive pooling and convolution kernel sizes from spacing and
    patch size.

    In every iteration the axes whose spacing is within a factor of 2 of the
    smallest current spacing are candidates for pooling. Candidates are
    dropped if pooling would shrink them below ``min_feature_map_size`` or if
    they already reached ``max_numpool``. The loop stops when no candidate is
    left. Conv kernels are 3 on the largest group of axes whose spacings are
    mutually within a factor of 2, and 1 elsewhere.

    :param spacing: voxel spacing, one per axis
    :param patch_size: patch edge lengths, one per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: upper bound on the number of pooling operations per axis
    :return: (num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes,
              patch_size padded to be divisible by must_be_divisible_by, must_be_divisible_by)
    """
    dim = len(spacing)

    current_spacing = deepcopy(list(spacing))
    current_size = deepcopy(list(patch_size))

    pool_op_kernel_sizes = []
    conv_kernel_sizes = []

    num_pool_per_axis = [0] * dim

    while True:
        # This is a problem because sometimes we have spacing 20, 50, 50 and we want to still keep pooling.
        # Here we would stop however. This is not what we want! Fixed in get_pool_and_conv_props_v2
        min_spacing = min(current_spacing)
        valid_axes_for_pool = [i for i in range(dim) if current_spacing[i] / min_spacing < 2]

        # largest set of axes whose spacings are mutually within a factor of 2
        axes = []
        for a in range(dim):
            my_spacing = current_spacing[a]
            partners = [i for i in range(dim)
                        if current_spacing[i] / my_spacing < 2 and my_spacing / current_spacing[i] < 2]
            if len(partners) > len(axes):
                axes = partners
        conv_kernel_size = [3 if i in axes else 1 for i in range(dim)]

        # exclude axes that we cannot pool further because of min_feature_map_size constraint
        valid_axes_for_pool = [i for i in valid_axes_for_pool if current_size[i] >= 2 * min_feature_map_size]
        # max_numpool constraint
        valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]

        if len(valid_axes_for_pool) == 0:
            # nothing left to pool: this is the only exit of the loop
            break

        pool_kernel_sizes = [1] * dim
        for v in valid_axes_for_pool:
            pool_kernel_sizes[v] = 2
            num_pool_per_axis[v] += 1
            current_spacing[v] *= 2
            current_size[v] = np.ceil(current_size[v] / 2)

        pool_op_kernel_sizes.append(pool_kernel_sizes)
        conv_kernel_sizes.append(conv_kernel_size)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
    conv_kernel_sizes.append([3] * dim)
    return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
def get_pool_and_conv_props_v2(spacing, patch_size, min_feature_map_size, max_numpool):
    """
    Iteratively derive pooling and convolution kernel sizes from spacing and
    patch size.

    Unlike ``get_pool_and_conv_props``, the smallest spacing is determined
    only among axes that are still large enough to be pooled, so pooling
    continues on the remaining axes once another axis is exhausted.

    Per iteration:
      1. keep axes whose size is at least ``2 * min_feature_map_size``;
      2. of those, keep axes within a factor of 2 of their smallest spacing;
      3. drop axes that already reached ``max_numpool``;
      4. stop if no axis is left, or if a single axis is left whose size is
         below ``3 * min_feature_map_size``;
      5. update conv kernel sizes, then pool the remaining axes.

    :param spacing: voxel spacing, one per axis
    :param patch_size: patch edge lengths, one per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: upper bound on the number of pooling operations per axis
    :return: (num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes,
              patch_size padded to be divisible by must_be_divisible_by, must_be_divisible_by)
    """
    dim = len(spacing)

    current_spacing = deepcopy(list(spacing))
    current_size = deepcopy(list(patch_size))

    pool_op_kernel_sizes = []
    conv_kernel_sizes = []

    num_pool_per_axis = [0] * dim
    kernel_size = [1] * dim

    while True:
        # exclude axes that we cannot pool further because of min_feature_map_size constraint
        valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2 * min_feature_map_size]
        if len(valid_axes_for_pool) < 1:
            break

        spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]

        # find axis that are within factor of 2 within smallest spacing
        min_spacing_of_valid = min(spacings_of_axes)
        valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]

        # max_numpool constraint
        valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]

        if len(valid_axes_for_pool) < 1:
            break
        if len(valid_axes_for_pool) == 1 and current_size[valid_axes_for_pool[0]] < 3 * min_feature_map_size:
            # a single remaining axis is only pooled while it is still comfortably large
            break

        # now we need to find kernel sizes
        # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
        # factor 2 of min_spacing. Once they are 3 they remain 3
        for d in range(dim):
            if kernel_size[d] == 3:
                continue
            # index the full per-axis spacing list: spacings_of_axes only holds the
            # poolable axes, so indexing it with d reads the wrong axis or overruns
            if current_spacing[d] / min(current_spacing) < 2:
                kernel_size[d] = 3

        pool_kernel_sizes = [1] * dim
        for v in valid_axes_for_pool:
            pool_kernel_sizes[v] = 2
            num_pool_per_axis[v] += 1
            current_spacing[v] *= 2
            current_size[v] = np.ceil(current_size[v] / 2)

        pool_op_kernel_sizes.append(pool_kernel_sizes)
        # copy: kernel_size keeps being modified in later iterations
        conv_kernel_sizes.append(deepcopy(kernel_size))

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
    conv_kernel_sizes.append([3] * dim)
    return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
def get_shape_must_be_divisible_by(net_numpool_per_axis):
    """
    Return, per axis, the factor a shape must be divisible by so that it can
    be halved ``net_numpool_per_axis[axis]`` times.

    :param net_numpool_per_axis: number of pooling operations per axis
    :return: numpy array holding ``2 ** n`` for every entry ``n``
    """
    return np.power(2, np.asarray(net_numpool_per_axis))
def pad_shape(shape, must_be_divisible_by):
    """
    pads shape so that it is divisible by must_be_divisible_by

    Every entry is rounded up to the next multiple of its divisor; entries
    that are already a multiple are left unchanged.

    :param shape: sequence of edge lengths
    :param must_be_divisible_by: a single divisor applied to every axis, or a
        tuple/list/np.ndarray with one divisor per axis
    :return: np.ndarray of ints with the padded shape
    """
    if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
        must_be_divisible_by = [must_be_divisible_by] * len(shape)
    else:
        assert len(must_be_divisible_by) == len(shape), \
            "must_be_divisible_by needs one entry per axis of shape"

    # -(-s // d) is ceil(s / d) without going through floating point division
    new_shp = [-(-s // d) * d for s, d in zip(shape, must_be_divisible_by)]
    return np.array(new_shp).astype(int)
def get_network_numpool(patch_size, maxpool_cap=999, min_feature_map_size=4):
    """
    Return how often each axis can be halved before its edge length would
    drop below ``min_feature_map_size``, capped at ``maxpool_cap``.

    :param patch_size: patch edge lengths, one per axis
    :param maxpool_cap: upper bound on the number of pooling operations per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :return: list of ints, one per axis
    """
    ratio = np.asarray(patch_size, dtype=float) / min_feature_map_size
    # log2 is exact for powers of two, so the floor cannot be thrown off by the
    # rounding of a log(x) / log(2) quotient
    numpool = np.floor(np.log2(ratio))
    # an axis already smaller than min_feature_map_size is simply never pooled
    numpool = np.maximum(numpool, 0)
    return [min(int(n), maxpool_cap) for n in numpool]
if __name__ == '__main__':
    # trying to fix https://github.com/MIC-DKFZ/nnUNet/issues/261
    # manual check with a strongly anisotropic case (first axis ~12x coarser)
    median_shape = [24, 504, 512]
    spacing = [5.9999094, 0.50781202, 0.50781202]
    (num_pool_per_axis,
     net_num_pool_op_kernel_sizes,
     net_conv_kernel_sizes,
     patch_size,
     must_be_divisible_by) = get_pool_and_conv_props_poolLateV2(
        median_shape,
        min_feature_map_size=4,
        max_numpool=999,
        spacing=spacing,
    )