ho11laqe's picture
init
ecf08bc
raw
history blame
No virus
9.03 kB
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from batchgenerators.transforms.abstract_transforms import AbstractTransform
from skimage.morphology import label, ball
from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
import numpy as np
class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform):
    """Randomly erase one connected component from selected channels of ``data_dict[key]``.

    The array is indexed as ``data[sample, channel]``. For every selected sample and channel the
    connected components of that channel are labelled, one sufficiently small component is drawn
    at random and set to 0. Optionally the erased region is set to 1 in another channel of
    ``channel_idx``. The array is modified in place.
    """

    def __init__(self, channel_idx, key="data", p_per_sample=0.2, fill_with_other_class_p=0.25,
                 dont_do_if_covers_more_than_X_percent=0.25, p_per_label=1):
        """
        :param channel_idx: channel index, or list/tuple of channel indices, to work on
        :param key: entry of the data dict that holds the array to modify
        :param p_per_sample: probability of processing a given sample at all
        :param fill_with_other_class_p: probability that the erased component is set to 1 in another,
            randomly chosen channel of ``channel_idx``
        :param dont_do_if_covers_more_than_X_percent: despite the name this is a fraction (0.25 means 25%).
            Only components smaller than this fraction of the channel's voxels are candidates for removal
        :param p_per_label: probability of processing each channel of a selected sample
        """
        self.p_per_label = p_per_label
        self.dont_do_if_covers_more_than_X_percent = dont_do_if_covers_more_than_X_percent
        self.fill_with_other_class_p = fill_with_other_class_p
        self.p_per_sample = p_per_sample
        self.key = key
        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        """Apply the transform to ``data_dict[self.key]`` in place and return the data dict."""
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() >= self.p_per_sample:
                continue
            for c in self.channel_idx:
                if np.random.uniform() >= self.p_per_label:
                    continue

                lab, num_comp = label(data[b, c], return_num=True)
                if num_comp == 0:
                    continue

                # Sizes of all components in a single pass; index 0 is the background and is skipped.
                component_sizes = np.bincount(lab.ravel(), minlength=num_comp + 1)[1:]
                num_voxels = np.prod(lab.shape, dtype=np.uint64)
                max_size = num_voxels * self.dont_do_if_covers_more_than_X_percent
                # Component ids start at 1, hence the +1 on the zero-based positions.
                component_ids = np.flatnonzero(component_sizes < max_size) + 1
                if component_ids.size == 0:
                    continue

                removed_mask = lab == np.random.choice(component_ids)
                data[b, c][removed_mask] = 0

                if np.random.uniform() < self.fill_with_other_class_p:
                    other_ch = [i for i in self.channel_idx if i != c]
                    if len(other_ch) > 0:
                        other_class = np.random.choice(other_ch)
                        data[b, other_class][removed_mask] = 1
        data_dict[self.key] = data
        return data_dict
class MoveSegAsOneHotToData(AbstractTransform):
    """Append one channel of ``data_dict[key_origin]`` to ``data_dict[key_target]`` as a one-hot encoding.

    Both arrays are indexed with the channel on axis 1. One new channel is created per entry of
    ``all_seg_labels`` and concatenated to the target along axis 1.
    """

    def __init__(self, channel_id, all_seg_labels, key_origin="seg", key_target="data", remove_from_origin=True):
        """
        :param channel_id: index (axis 1) of the origin channel to convert. Negative values count from the end
        :param all_seg_labels: label values; one one-hot channel is created per entry, in this order
        :param key_origin: data dict entry the channel is read from
        :param key_target: data dict entry the one-hot channels are appended to
        :param remove_from_origin: if True, the converted channel is dropped from the origin array
        """
        self.remove_from_origin = remove_from_origin
        self.all_seg_labels = all_seg_labels
        self.key_target = key_target
        self.key_origin = key_origin
        self.channel_id = channel_id

    def __call__(self, **data_dict):
        """Build the one-hot channels, update target (and optionally origin) and return the data dict."""
        origin = data_dict.get(self.key_origin)
        target = data_dict.get(self.key_target)

        # Integer indexing (instead of the slice channel_id:channel_id + 1) handles negative ids
        # correctly and raises IndexError for out-of-range ids.
        seg = origin[:, self.channel_id]

        seg_onehot = np.zeros((seg.shape[0], len(self.all_seg_labels), *seg.shape[1:]), dtype=seg.dtype)
        for i, l in enumerate(self.all_seg_labels):
            seg_onehot[:, i][seg == l] = 1

        target = np.concatenate((target, seg_onehot), 1)
        data_dict[self.key_target] = target

        if self.remove_from_origin:
            num_channels = origin.shape[1]
            # Normalise negative ids so the comparison below actually matches the converted channel.
            removed_channel = self.channel_id % num_channels
            remaining_channels = [i for i in range(num_channels) if i != removed_channel]
            origin = origin[:, remaining_channels]
            data_dict[self.key_origin] = origin
        return data_dict
class ApplyRandomBinaryOperatorTransform(AbstractTransform):
    """Apply a randomly chosen binary morphological operator to selected channels of ``data_dict[key]``.

    The array is indexed as ``data[sample, channel]`` and modified in place. Wherever the operator
    adds foreground to a channel, the other channels of ``channel_idx`` are set to 0 there so the
    channels stay mutually exclusive. Wherever foreground is removed, nothing else is changed.
    """

    def __init__(self, channel_idx, p_per_sample=0.3, any_of_these=(binary_dilation, binary_erosion, binary_closing,
                                                                     binary_opening),
                 key="data", strel_size=(1, 10), p_per_label=1):
        """
        :param channel_idx: channel index, or list/tuple of channel indices, to work on
        :param p_per_sample: probability of processing a given sample at all
        :param any_of_these: candidate operators; each is called as ``operation(image, structuring_element)``
        :param key: entry of the data dict that holds the array to modify
        :param strel_size: ``(low, high)`` bounds of the uniform distribution the ``ball`` radius is drawn from
        :param p_per_label: probability of processing each channel of a selected sample
        """
        self.p_per_label = p_per_label
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample
        # The channel order is shuffled in place per sample, so a mutable list is required.
        if isinstance(channel_idx, tuple):
            channel_idx = list(channel_idx)
        elif not isinstance(channel_idx, list):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        """Apply the transform to ``data_dict[self.key]`` in place and return the data dict."""
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() >= self.p_per_sample:
                continue
            # Work on a shuffled copy so the configured channel order is left untouched.
            ch = list(self.channel_idx)
            np.random.shuffle(ch)
            for c in ch:
                if np.random.uniform() >= self.p_per_label:
                    continue
                operation = np.random.choice(self.any_of_these)
                selem = ball(np.random.uniform(*self.strel_size))
                workon = data[b, c].astype(int)  # astype already returns a copy
                res = operation(workon, selem).astype(workon.dtype)
                data[b, c] = res

                # If the class was added somewhere, remove it in ALL other channels to keep the
                # one hot encoding properties. If the class was removed, leave it at background.
                other_ch = [i for i in ch if i != c]
                if len(other_ch) > 0:
                    was_added_mask = (res - workon) > 0
                    for oc in other_ch:
                        data[b, oc][was_added_mask] = 0
        data_dict[self.key] = data
        return data_dict
class ApplyRandomBinaryOperatorTransform2(AbstractTransform):
    """Apply a randomly chosen expanding binary operator to selected channels of ``data_dict[key]``.

    2019_11_22: I have no idea what the purpose of this was...
    Same procedure as ApplyRandomBinaryOperatorTransform, but the defaults use only expanding
    operations (dilation, closing) and ``p_per_label`` defaults to 0.3. Expansions replace other
    labels: wherever a channel gains foreground, the other channels of ``channel_idx`` are set to 0.
    The array is indexed as ``data[sample, channel]`` and modified in place.
    """

    def __init__(self, channel_idx, p_per_sample=0.3, p_per_label=0.3, any_of_these=(binary_dilation, binary_closing),
                 key="data", strel_size=(1, 10)):
        """
        :param channel_idx: channel index, or list/tuple of channel indices, to work on
        :param p_per_sample: probability of processing a given sample at all
        :param p_per_label: probability of processing each channel of a selected sample
        :param any_of_these: candidate operators; each is called as ``operation(image, structuring_element)``
        :param key: entry of the data dict that holds the array to modify
        :param strel_size: ``(low, high)`` bounds of the uniform distribution the ``ball`` radius is drawn from
        """
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample
        self.p_per_label = p_per_label
        # The channel order is shuffled in place per sample, so a mutable list is required.
        if isinstance(channel_idx, tuple):
            channel_idx = list(channel_idx)
        elif not isinstance(channel_idx, list):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        """Apply the transform to ``data_dict[self.key]`` in place and return the data dict."""
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() >= self.p_per_sample:
                continue
            # Work on a shuffled copy so the configured channel order is left untouched.
            ch = list(self.channel_idx)
            np.random.shuffle(ch)
            for c in ch:
                if np.random.uniform() >= self.p_per_label:
                    continue
                operation = np.random.choice(self.any_of_these)
                selem = ball(np.random.uniform(*self.strel_size))
                workon = data[b, c].astype(int)  # astype already returns a copy
                res = operation(workon, selem).astype(workon.dtype)
                data[b, c] = res

                # If the class was added somewhere, remove it in ALL other channels to keep the
                # one hot encoding properties. If the class was removed, leave it at background.
                other_ch = [i for i in ch if i != c]
                if len(other_ch) > 0:
                    was_added_mask = (res - workon) > 0
                    for oc in other_ch:
                        data[b, oc][was_added_mask] = 0
        data_dict[self.key] = data
        return data_dict