|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from copy import deepcopy |
|
|
|
from batchgenerators.transforms.abstract_transforms import AbstractTransform |
|
from skimage.morphology import label, ball |
|
from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening |
|
import numpy as np |
|
|
|
|
|
class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform): |
|
def __init__(self, channel_idx, key="data", p_per_sample=0.2, fill_with_other_class_p=0.25, |
|
dont_do_if_covers_more_than_X_percent=0.25, p_per_label=1): |
|
""" |
|
:param dont_do_if_covers_more_than_X_percent: dont_do_if_covers_more_than_X_percent=0.25 is 25\%! |
|
:param channel_idx: can be list or int |
|
:param key: |
|
""" |
|
self.p_per_label = p_per_label |
|
self.dont_do_if_covers_more_than_X_percent = dont_do_if_covers_more_than_X_percent |
|
self.fill_with_other_class_p = fill_with_other_class_p |
|
self.p_per_sample = p_per_sample |
|
self.key = key |
|
if not isinstance(channel_idx, (list, tuple)): |
|
channel_idx = [channel_idx] |
|
self.channel_idx = channel_idx |
|
|
|
def __call__(self, **data_dict): |
|
data = data_dict.get(self.key) |
|
for b in range(data.shape[0]): |
|
if np.random.uniform() < self.p_per_sample: |
|
for c in self.channel_idx: |
|
if np.random.uniform() < self.p_per_label: |
|
workon = np.copy(data[b, c]) |
|
num_voxels = np.prod(workon.shape, dtype=np.uint64) |
|
lab, num_comp = label(workon, return_num=True) |
|
if num_comp > 0: |
|
component_ids = [] |
|
component_sizes = [] |
|
for i in range(1, num_comp + 1): |
|
component_ids.append(i) |
|
component_sizes.append(np.sum(lab == i)) |
|
component_ids = [i for i, j in zip(component_ids, component_sizes) if j < num_voxels*self.dont_do_if_covers_more_than_X_percent] |
|
|
|
|
|
|
|
if len(component_ids) > 0: |
|
random_component = np.random.choice(component_ids) |
|
data[b, c][lab == random_component] = 0 |
|
if np.random.uniform() < self.fill_with_other_class_p: |
|
other_ch = [i for i in self.channel_idx if i != c] |
|
if len(other_ch) > 0: |
|
other_class = np.random.choice(other_ch) |
|
data[b, other_class][lab == random_component] = 1 |
|
data_dict[self.key] = data |
|
return data_dict |
|
|
|
|
|
class MoveSegAsOneHotToData(AbstractTransform): |
|
def __init__(self, channel_id, all_seg_labels, key_origin="seg", key_target="data", remove_from_origin=True): |
|
self.remove_from_origin = remove_from_origin |
|
self.all_seg_labels = all_seg_labels |
|
self.key_target = key_target |
|
self.key_origin = key_origin |
|
self.channel_id = channel_id |
|
|
|
def __call__(self, **data_dict): |
|
origin = data_dict.get(self.key_origin) |
|
target = data_dict.get(self.key_target) |
|
seg = origin[:, self.channel_id:self.channel_id+1] |
|
seg_onehot = np.zeros((seg.shape[0], len(self.all_seg_labels), *seg.shape[2:]), dtype=seg.dtype) |
|
for i, l in enumerate(self.all_seg_labels): |
|
seg_onehot[:, i][seg[:, 0] == l] = 1 |
|
target = np.concatenate((target, seg_onehot), 1) |
|
data_dict[self.key_target] = target |
|
|
|
if self.remove_from_origin: |
|
remaining_channels = [i for i in range(origin.shape[1]) if i != self.channel_id] |
|
origin = origin[:, remaining_channels] |
|
data_dict[self.key_origin] = origin |
|
return data_dict |
|
|
|
|
|
class ApplyRandomBinaryOperatorTransform(AbstractTransform): |
|
def __init__(self, channel_idx, p_per_sample=0.3, any_of_these=(binary_dilation, binary_erosion, binary_closing, |
|
binary_opening), |
|
key="data", strel_size=(1, 10), p_per_label=1): |
|
self.p_per_label = p_per_label |
|
self.strel_size = strel_size |
|
self.key = key |
|
self.any_of_these = any_of_these |
|
self.p_per_sample = p_per_sample |
|
|
|
assert not isinstance(channel_idx, tuple), "bäh" |
|
|
|
if not isinstance(channel_idx, list): |
|
channel_idx = [channel_idx] |
|
self.channel_idx = channel_idx |
|
|
|
def __call__(self, **data_dict): |
|
data = data_dict.get(self.key) |
|
for b in range(data.shape[0]): |
|
if np.random.uniform() < self.p_per_sample: |
|
ch = deepcopy(self.channel_idx) |
|
np.random.shuffle(ch) |
|
for c in ch: |
|
if np.random.uniform() < self.p_per_label: |
|
operation = np.random.choice(self.any_of_these) |
|
selem = ball(np.random.uniform(*self.strel_size)) |
|
workon = np.copy(data[b, c]).astype(int) |
|
res = operation(workon, selem).astype(workon.dtype) |
|
data[b, c] = res |
|
|
|
|
|
|
|
|
|
other_ch = [i for i in ch if i != c] |
|
if len(other_ch) > 0: |
|
was_added_mask = (res - workon) > 0 |
|
for oc in other_ch: |
|
data[b, oc][was_added_mask] = 0 |
|
|
|
data_dict[self.key] = data |
|
return data_dict |
|
|
|
|
|
class ApplyRandomBinaryOperatorTransform2(AbstractTransform): |
|
def __init__(self, channel_idx, p_per_sample=0.3, p_per_label=0.3, any_of_these=(binary_dilation, binary_closing), |
|
key="data", strel_size=(1, 10)): |
|
""" |
|
2019_11_22: I have no idea what the purpose of this was... |
|
|
|
the same as above but here we should use only expanding operations. Expansions will replace other labels |
|
:param channel_idx: can be list or int |
|
:param p_per_sample: |
|
:param any_of_these: |
|
:param fill_diff_with_other_class: |
|
:param key: |
|
:param strel_size: |
|
""" |
|
self.strel_size = strel_size |
|
self.key = key |
|
self.any_of_these = any_of_these |
|
self.p_per_sample = p_per_sample |
|
self.p_per_label = p_per_label |
|
|
|
assert not isinstance(channel_idx, tuple), "bäh" |
|
|
|
if not isinstance(channel_idx, list): |
|
channel_idx = [channel_idx] |
|
self.channel_idx = channel_idx |
|
|
|
def __call__(self, **data_dict): |
|
data = data_dict.get(self.key) |
|
for b in range(data.shape[0]): |
|
if np.random.uniform() < self.p_per_sample: |
|
ch = deepcopy(self.channel_idx) |
|
np.random.shuffle(ch) |
|
for c in ch: |
|
if np.random.uniform() < self.p_per_label: |
|
operation = np.random.choice(self.any_of_these) |
|
selem = ball(np.random.uniform(*self.strel_size)) |
|
workon = np.copy(data[b, c]).astype(int) |
|
res = operation(workon, selem).astype(workon.dtype) |
|
data[b, c] = res |
|
|
|
|
|
|
|
|
|
other_ch = [i for i in ch if i != c] |
|
if len(other_ch) > 0: |
|
was_added_mask = (res - workon) > 0 |
|
for oc in other_ch: |
|
data[b, oc][was_added_mask] = 0 |
|
|
|
data_dict[self.key] = data |
|
return data_dict |
|
|