|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
from copy import deepcopy |
|
from nnunet.network_architecture.generic_UNet import Generic_UNet |
|
import SimpleITK as sitk |
|
import shutil |
|
from batchgenerators.utilities.file_and_folder_operations import join |
|
|
|
|
|
def split_4d_nifti(filename, output_folder):
    """
    Split a 4D image into one 3D file per entry along the first axis of its array representation.

    A 3D image is copied unchanged to ``<base>_0000.nii.gz``. For a 4D image, entry ``i`` of the array returned
    by ``sitk.GetArrayFromImage`` is written to ``<base>_<i zero-padded to 4 digits>.nii.gz``; spacing, origin
    and direction are taken from the input with the entries of the last image axis removed.

    :param filename: path of the image to split. The last 7 characters of its base name are cut off to build
        ``<base>`` (presumably the ".nii.gz" extension -- other extensions are not handled here)
    :param output_folder: folder the output files are written to
    :return: None
    :raises RuntimeError: if the image is neither 3D nor 4D
    """
    # Function-scope import keeps this block self-contained. os.path.basename honours the path separators of
    # the host OS, whereas splitting on "/" keeps directory parts in the name for backslash-separated paths.
    import os

    img_itk = sitk.ReadImage(filename)
    dim = img_itk.GetDimension()
    file_stem = os.path.basename(filename)[:-7]

    if dim == 3:
        # nothing to split: the file is just copied under the first-channel name
        shutil.copy(filename, join(output_folder, file_stem + "_0000.nii.gz"))
        return
    if dim != 4:
        raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename))

    img_npy = sitk.GetArrayFromImage(img_itk)

    # remove the entries that belong to the last (fourth) image axis from the geometry
    spacing = tuple(img_itk.GetSpacing()[:-1])
    origin = tuple(img_itk.GetOrigin()[:-1])
    direction = np.array(img_itk.GetDirection()).reshape(4, 4)
    direction = tuple(direction[:-1, :-1].reshape(-1))

    for i, img in enumerate(img_npy):
        img_itk_new = sitk.GetImageFromArray(img)
        img_itk_new.SetSpacing(spacing)
        img_itk_new.SetOrigin(origin)
        img_itk_new.SetDirection(direction)
        sitk.WriteImage(img_itk_new, join(output_folder, file_stem + "_%04.0d.nii.gz" % i))
|
|
|
|
|
def get_pool_and_conv_props_poolLateV2(patch_size, min_feature_map_size, max_numpool, spacing):
    """
    Build per-stage pooling and convolution kernel sizes where axes with fewer poolings are pooled in the
    last stages only.

    The number of poolings per axis comes from get_network_numpool. In stage p an axis is pooled (kernel 2) if
    its pooling count plus p reaches the largest pooling count, otherwise it gets pooling kernel 1. An axis
    counts as "reached" once its current spacing exceeds half of the largest input spacing; reached axes get
    convolution kernel 1 and the others 3, unless all axes are reached, in which case every axis gets 3.

    :param patch_size: patch size per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: upper bound for the number of poolings per axis
    :param spacing: spacing per axis
    :return: num_pool_per_axis, pooling kernel sizes per stage, convolution kernel sizes per stage (plus one
        trailing [3] * dim entry), the patch size padded by pad_shape, and must_be_divisible_by
    """
    largest_spacing = max(deepcopy(spacing))
    n_axes = len(patch_size)

    num_pool_per_axis = get_network_numpool(patch_size, max_numpool, min_feature_map_size)
    n_stages = max(num_pool_per_axis)

    pool_kernels = []
    conv_kernels = []

    stage_spacing = spacing
    for stage in range(n_stages):
        is_reached = [stage_spacing[ax] / largest_spacing > 0.5 for ax in range(n_axes)]
        stage_pool = [1 if num_pool_per_axis[ax] + stage < n_stages else 2 for ax in range(n_axes)]

        if all(is_reached):
            stage_conv = [3] * n_axes
        else:
            stage_conv = [1 if flag else 3 for flag in is_reached]

        pool_kernels.append(stage_pool)
        conv_kernels.append(stage_conv)
        # spacing doubles along every axis that was pooled in this stage
        stage_spacing = [s * k for s, k in zip(stage_spacing, stage_pool)]

    # kernel sizes of the stage that follows the last pooling
    conv_kernels.append([3] * n_axes)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    return num_pool_per_axis, pool_kernels, conv_kernels, patch_size, must_be_divisible_by
|
|
|
|
|
def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
    """
    Derive per-stage pooling and convolution kernel sizes from spacing and patch size.

    In every stage the axes whose spacing is less than twice the currently smallest spacing are pooled with
    kernel 2, provided their current size is at least 2 * min_feature_map_size and they have been pooled fewer
    than max_numpool times. All other axes get pooling kernel 1. The loop stops as soon as no axis can be
    pooled. The convolution kernel of a stage is 3 on the largest group of axes whose spacings are within a
    factor of 2 of each other (first such group on ties) and 1 on the remaining axes.

    :param spacing: spacing per axis
    :param patch_size: patch size per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: maximum number of poolings per axis
    :return: num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes (one entry per stage plus a trailing
        [3] * dim), the patch size padded by pad_shape, and must_be_divisible_by
    """
    dim = len(spacing)
    all_axes = range(dim)

    # work on copies so that the caller's sequences are not modified
    spacing_now = deepcopy(list(spacing))
    size_now = deepcopy(list(patch_size))

    num_pool_per_axis = [0] * dim
    pool_op_kernel_sizes = []
    conv_kernel_sizes = []

    while True:
        finest = min(spacing_now)
        near_finest = [ax for ax in all_axes if spacing_now[ax] / finest < 2]

        # largest set of axes with mutually similar spacing relative to one reference axis
        largest_group = []
        for ref in all_axes:
            ref_spacing = spacing_now[ref]
            group = [ax for ax in all_axes
                     if spacing_now[ax] / ref_spacing < 2 and ref_spacing / spacing_now[ax] < 2]
            if len(group) > len(largest_group):
                largest_group = group
        stage_conv = [3 if ax in largest_group else 1 for ax in all_axes]

        # drop axes that are too small to be halved again or that used up their pooling budget
        poolable = [ax for ax in near_finest
                    if size_now[ax] >= 2 * min_feature_map_size and num_pool_per_axis[ax] < max_numpool]
        if not poolable:
            break

        stage_pool = [2 if ax in poolable else 1 for ax in all_axes]
        for ax in poolable:
            num_pool_per_axis[ax] += 1
            spacing_now[ax] *= 2
            size_now[ax] = np.ceil(size_now[ax] / 2)

        pool_op_kernel_sizes.append(stage_pool)
        conv_kernel_sizes.append(stage_conv)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    # kernel sizes of the stage that follows the last pooling
    conv_kernel_sizes.append([3] * dim)
    return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
|
|
|
|
|
def get_pool_and_conv_props_v2(spacing, patch_size, min_feature_map_size, max_numpool):
    """
    Derive per-stage pooling and convolution kernel sizes from spacing and patch size.

    Per stage:
      * candidate axes are those whose current size is at least 2 * min_feature_map_size;
      * of these, only axes whose spacing is less than twice the smallest candidate spacing and that have been
        pooled fewer than max_numpool times are pooled (kernel 2); every other axis gets pooling kernel 1;
      * if a single axis remains it is only pooled while its size is at least 3 * min_feature_map_size;
      * the loop ends when nothing can be pooled.
    Convolution kernels start at 1 per axis and are switched to 3 once the spacing of the axis is less than
    twice the smallest current spacing; after that they stay 3.

    :param spacing: spacing per axis
    :param patch_size: patch size per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: maximum number of poolings per axis
    :return: num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes (one entry per stage plus a trailing
        [3] * dim), the patch size padded by pad_shape, and must_be_divisible_by
    """
    dim = len(spacing)

    # work on copies so that the caller's sequences are not modified
    current_spacing = deepcopy(list(spacing))
    current_size = deepcopy(list(patch_size))

    pool_op_kernel_sizes = []
    conv_kernel_sizes = []

    num_pool_per_axis = [0] * dim
    kernel_size = [1] * dim

    while True:
        # axes that are still large enough to be halved
        valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2 * min_feature_map_size]
        if not valid_axes_for_pool:
            break

        # only pool axes whose spacing is close to the finest spacing among the candidates
        min_spacing_of_valid = min(current_spacing[i] for i in valid_axes_for_pool)
        valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]

        # respect the per-axis pooling budget
        valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]

        if not valid_axes_for_pool:
            break
        if len(valid_axes_for_pool) == 1 and current_size[valid_axes_for_pool[0]] < 3 * min_feature_map_size:
            # a lone axis is only pooled further while it is comfortably above the minimum size
            break

        # The spacing must be looked up by axis index in current_spacing. A list restricted to the candidate
        # axes has fewer than dim entries whenever an axis was excluded above, so indexing it with d would
        # read the wrong axis or raise IndexError.
        min_spacing = min(current_spacing)
        for d in range(dim):
            if kernel_size[d] != 3 and current_spacing[d] / min_spacing < 2:
                kernel_size[d] = 3

        pool_kernel_sizes = [2 if i in valid_axes_for_pool else 1 for i in range(dim)]
        for v in valid_axes_for_pool:
            num_pool_per_axis[v] += 1
            current_spacing[v] *= 2
            current_size[v] = np.ceil(current_size[v] / 2)

        pool_op_kernel_sizes.append(pool_kernel_sizes)
        # kernel_size keeps being updated in later stages, so store a snapshot
        conv_kernel_sizes.append(deepcopy(kernel_size))

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    # kernel sizes of the stage that follows the last pooling
    conv_kernel_sizes.append([3] * dim)
    return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
|
|
|
|
|
def get_shape_must_be_divisible_by(net_numpool_per_axis):
    """
    Return 2 raised to the number of poolings of each axis.

    :param net_numpool_per_axis: number of poolings per axis
    :return: numpy array holding 2 ** n for every entry n of net_numpool_per_axis
    """
    pool_counts = np.array(net_numpool_per_axis)
    return np.power(2, pool_counts)
|
|
|
|
|
def pad_shape(shape, must_be_divisible_by):
    """
    Round every entry of shape up to the next multiple of its divisor.

    Entries that already are a multiple of their divisor are left unchanged.

    :param shape: sequence of sizes, one per axis
    :param must_be_divisible_by: either one divisor per axis (tuple, list or numpy array with the same length
        as shape) or a single value that is used for all axes
    :return: numpy integer array with the padded shape
    """
    n_axes = len(shape)
    if isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
        assert len(must_be_divisible_by) == n_axes
    else:
        # a single divisor applies to every axis
        must_be_divisible_by = [must_be_divisible_by] * n_axes

    padded = []
    for axis, extent in enumerate(shape):
        divisor = must_be_divisible_by[axis]
        remainder = extent % divisor
        padded.append(extent if remainder == 0 else extent + divisor - remainder)

    return np.array(padded).astype(int)
|
|
|
|
|
def get_network_numpool(patch_size, maxpool_cap=999, min_feature_map_size=4):
    """
    Compute how often each axis can be halved before it would drop below min_feature_map_size.

    For every axis this is floor(log2(size / min_feature_map_size)), limited to the range [0, maxpool_cap].
    Axes that are smaller than min_feature_map_size get 0 poolings rather than a negative count, which could
    not be turned into a divisibility factor of 2 ** n afterwards.

    :param patch_size: sequence of sizes, one per axis
    :param maxpool_cap: upper bound for the number of poolings per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :return: list of ints with the number of poolings per axis
    """
    ratios = np.asarray(patch_size, dtype=float) / min_feature_map_size
    # log2 avoids the rounding error of the log(x) / log(2) quotient
    numpool = np.floor(np.log2(ratios)).astype(int)
    return [int(min(max(n, 0), maxpool_cap)) for n in numpool]
|
|
|
|
|
if __name__ == '__main__':
    # small manual example: compute the network topology for one shape / spacing combination
    median_shape = [24, 504, 512]
    spacing = [5.9999094, 0.50781202, 0.50781202]

    (num_pool_per_axis,
     net_num_pool_op_kernel_sizes,
     net_conv_kernel_sizes,
     patch_size,
     must_be_divisible_by) = get_pool_and_conv_props_poolLateV2(
        patch_size=median_shape,
        min_feature_map_size=4,
        max_numpool=999,
        spacing=spacing,
    )
|
|