|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from batchgenerators.augmentations.utils import convert_seg_image_to_one_hot_encoding_batched, resize_segmentation |
|
from batchgenerators.transforms.abstract_transforms import AbstractTransform |
|
from torch.nn.functional import avg_pool2d, avg_pool3d |
|
import numpy as np |
|
|
|
|
|
class DownsampleSegForDSTransform3(AbstractTransform):
    """Replace a segmentation with a deep-supervision pyramid of (soft) one-hot maps.

    The pyramid is produced by ``downsample_seg_for_ds_transform3``:

    * at full resolution (all scale factors equal to 1) the label map itself is
      returned, not a one-hot encoding;
    * at every downsampled level a one-hot encoding is average-pooled, so the
      values are soft class proportions rather than 0/1.

    The results are torch tensors, not numpy arrays.

    Only channel 0 of the input segmentation is used.

    Always pass ``classes``; without it the one-hot encoding may behave unexpectedly.
    """

    def __init__(self, ds_scales=(1, 0.5, 0.25), input_key="seg", output_key="seg", classes=None):
        # Configuration only; all work happens in __call__.
        self.ds_scales = ds_scales
        self.input_key = input_key
        self.output_key = output_key
        self.classes = classes  # labels forwarded to the one-hot encoding

    def __call__(self, **data_dict):
        # Select channel 0 only, which also drops the channel axis.
        first_channel = data_dict[self.input_key][:, 0]
        pyramid = downsample_seg_for_ds_transform3(first_channel, self.ds_scales, self.classes)
        data_dict[self.output_key] = pyramid
        return data_dict
|
|
|
|
|
def downsample_seg_for_ds_transform3(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), classes=None):
    """Build a deep-supervision pyramid by average-pooling one-hot encodings of ``seg``.

    For every entry of ``ds_scales``:

    * if all per-axis factors are 1, ``seg`` itself is appended as a torch tensor
      (the label map, no one-hot encoding);
    * otherwise the one-hot encoding of ``seg`` is average-pooled with kernel size
      and stride ``int(1 / factor)`` per axis and padding ``(kernel - 1) // 2``
      (``count_include_pad=False``), yielding soft class proportions instead of
      0/1 values.

    Args:
        seg: numpy label array with a leading batch axis and no channel axis,
            presumably (b, x, y) or (b, x, y, z); the transform in this module
            passes ``seg[:, 0]``.
        ds_scales: iterable of scales. A scale is either a sequence with one factor
            per spatial axis (2 entries -> 2D pooling, 3 entries -> 3D pooling) or a
            single number, which is applied to all ``seg.ndim - 1`` spatial axes.
        classes: labels forwarded to ``convert_seg_image_to_one_hot_encoding_batched``;
            should be given explicitly.

    Returns:
        list of torch tensors, one per scale.

    Raises:
        RuntimeError: if a downsampling scale has neither 2 nor 3 factors.
    """
    n_spatial = seg.ndim - 1  # every axis after the batch axis is spatial
    one_hot = None  # built lazily: only needed once a scale actually downsamples
    output = []

    for s in ds_scales:
        if np.ndim(s) == 0:
            # A bare number (e.g. 0.5) means the same factor along every spatial axis.
            s = (s,) * n_spatial

        if all(i == 1 for i in s):
            output.append(torch.from_numpy(seg))
            continue

        if len(s) == 2:
            pool_op = avg_pool2d
        elif len(s) == 3:
            pool_op = avg_pool3d
        else:
            raise RuntimeError(
                "downsample_seg_for_ds_transform3 supports only 2D or 3D scales, "
                "got %d factors: %r" % (len(s), s)
            )

        if one_hot is None:
            one_hot = torch.from_numpy(convert_seg_image_to_one_hot_encoding_batched(seg, classes))
            if not one_hot.is_floating_point():
                # Averaging must produce fractional class proportions; pooling an
                # integer tensor would truncate them (or be unsupported for that dtype).
                one_hot = one_hot.float()

        kernel_size = tuple(int(1 / i) for i in s)
        stride = kernel_size  # non-overlapping pooling windows
        pad = tuple((k - 1) // 2 for k in kernel_size)
        output.append(pool_op(one_hot, kernel_size, stride, pad, count_include_pad=False, ceil_mode=False))

    return output
|
|
|
|
|
class DownsampleSegForDSTransform2(AbstractTransform):
    """Replace a segmentation with a list of copies resized to each scale in ``ds_scales``.

    After the call, ``data_dict[output_key]`` holds that list, computed by
    ``downsample_seg_for_ds_transform2`` from ``data_dict[input_key]``.
    """

    def __init__(self, ds_scales=(1, 0.5, 0.25), order=0, input_key="seg", output_key="seg", axes=None):
        # Configuration only; all work happens in __call__.
        self.ds_scales = ds_scales
        self.order = order  # forwarded unchanged to resize_segmentation
        self.input_key = input_key
        self.output_key = output_key
        self.axes = axes  # axes to rescale; None lets the helper pick every axis from 2 onward

    def __call__(self, **data_dict):
        source = data_dict[self.input_key]
        data_dict[self.output_key] = downsample_seg_for_ds_transform2(source, self.ds_scales, self.order, self.axes)
        return data_dict
|
|
|
|
|
def downsample_seg_for_ds_transform2(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), order=0, axes=None):
    """Resize ``seg`` to every scale in ``ds_scales`` and return the results as a list.

    Args:
        seg: numpy segmentation whose first two axes are looped over as
            (batch, channel) and whose remaining axes are resized.
        ds_scales: iterable of scales. A scale is either a sequence of factors, the
            i-th factor applying to ``axes[i]``, or a single number applied to every
            axis in ``axes``. A scale whose factors are all 1 yields ``seg`` itself
            (not a copy).
        order: forwarded unchanged as the third argument of ``resize_segmentation``
            (presumably the interpolation order; confirm against batchgenerators).
        axes: axes of ``seg`` to rescale; defaults to every axis from 2 onward.
            NOTE(review): axes 0 and 1 are iterated over, so presumably they must
            not be listed here.

    Returns:
        list of numpy arrays with ``seg.dtype``, one per scale.
    """
    if axes is None:
        axes = list(range(2, len(seg.shape)))

    output = []
    for s in ds_scales:
        if np.ndim(s) == 0:
            # A bare number (e.g. 0.5) means the same factor along every axis in ``axes``.
            s = (s,) * len(axes)

        if all(i == 1 for i in s):
            output.append(seg)
            continue

        new_shape = np.array(seg.shape).astype(float)
        for i, a in enumerate(axes):
            new_shape[a] *= s[i]
        new_shape = np.round(new_shape).astype(int)

        out_seg = np.zeros(new_shape, dtype=seg.dtype)
        for b in range(seg.shape[0]):
            for c in range(seg.shape[1]):
                # Each (batch, channel) slice is resized independently to the target
                # shape of the remaining axes.
                out_seg[b, c] = resize_segmentation(seg[b, c], new_shape[2:], order)
        output.append(out_seg)
    return output
|
|