Paul Engstler
Initial commit
92f0e98
raw
history blame
No virus
2.76 kB
import numpy as np
from typing import List, Tuple
from monai.transforms import Compose, AddChannelD, MaskIntensityD, DeleteItemsD, CropForegroundD, ResizeD
class SelectMaskByLevelD:
    """
    Dictionary transform that keeps a single segment of a label mask.

    Every element of ``data[mask_key]`` equal to ``data[level_idx_key]`` becomes 1,
    every other element becomes 0; the binary result has the same shape and dtype
    as the original mask (it is built with ``np.zeros_like``). May also be applied
    to a single channel.

    Args:
        mask_key: key of the label mask in the data dictionary.
        level_idx_key: key of the label value to keep.
    """

    def __init__(self, mask_key: str, level_idx_key: str):
        self.mask_key, self.level_idx_key = mask_key, level_idx_key

    def __call__(self, data):
        # Work on a shallow copy so the caller's dictionary is not modified.
        result = dict(data)
        labels = result[self.mask_key]
        binary = np.zeros_like(labels)
        binary[labels == result[self.level_idx_key]] = 1
        result[self.mask_key] = binary
        return result
def get_mask_transform(hparams, loaded_keys: List[str], level_idx_key='level_idx') -> Tuple[Compose, List[str]]:
    """
    Build the mask-handling transform selected by ``hparams.mask``.

    Depending on ``hparams.mask``, the returned transform does one of the following:
    - nothing ('none')
    - applies the mask of the critical vertebra to the image ('apply')
    - applies the mask of all visible vertebrae to the image ('apply_all')
    - keeps only the critical vertebra in the mask key s.t. it will later be stacked with the image ('channel')
    - crops the image to the critical vertebra and resizes it to ``hparams.input_size`` per
      spatial axis ('crop')

    Args:
        hparams: configuration object; ``mask`` is always read, ``input_size`` and
            ``input_dim`` are read for 'crop'.
        loaded_keys: data keys produced by loading; for every mode except 'none' this
            must be exactly ``[image_key, mask_key]``.
        level_idx_key: data key holding the label value that identifies the critical
            vertebra inside the mask.

    Returns:
        ``(transform, keys)`` where ``keys`` lists the data keys that remain after the
        transform has run.

    Raises:
        ValueError: if ``hparams.mask`` is not a known mode, or if ``loaded_keys`` does
            not contain exactly two keys for a mode other than 'none'.
    """
    valid_modes = ('none', 'apply', 'apply_all', 'channel', 'crop')
    mode = hparams.mask
    if mode not in valid_modes:
        # Fail loudly instead of implicitly returning None, which would only surface
        # later as an opaque tuple-unpacking error at the call site.
        raise ValueError(f'Unknown mask mode {mode!r}; expected one of {valid_modes}')

    if mode == 'none':
        return Compose([]), loaded_keys

    # Every other mode needs the image and its segmentation mask, in that order.
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(loaded_keys) != 2:
        raise ValueError(f'Mask mode {mode!r} expects exactly [image_key, mask_key], got {loaded_keys!r}')
    image_key, mask_key = loaded_keys

    if mode == 'apply':
        return Compose([
            # only select relevant vertebra
            SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
            # apply mask
            MaskIntensityD(keys=image_key, mask_key=mask_key),
            # once the mask is applied, release it
            DeleteItemsD(keys=mask_key),
        ]), [image_key]

    if mode == 'apply_all':
        return Compose([
            # keeps all vertebrae in the mask; apply it as-is
            MaskIntensityD(keys=image_key, mask_key=mask_key),
            # once the mask is applied, release it
            DeleteItemsD(keys=mask_key),
        ]), [image_key]

    if mode == 'channel':
        return Compose([
            SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
        ]), loaded_keys

    # mode == 'crop'
    # ResizeD needs 1 + input_dim axes and AddChannelD supplies the channel axis, so
    # image and mask presumably arrive channel-less here (TODO confirm against the
    # loader). CropForegroundD expects channel-first data: without a channel axis it
    # treats the first spatial axis as channels and never crops it. Adding the
    # channel axis to both keys *before* cropping makes the crop cover every
    # spatial axis.
    spatial_dims = hparams.input_dim
    # Interpolation must match the number of spatial axes; 3D keeps 'trilinear'.
    interp_mode = {1: 'linear', 2: 'bilinear'}.get(spatial_dims, 'trilinear')
    return Compose([
        SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
        AddChannelD(keys=[image_key, mask_key]),
        CropForegroundD(keys=image_key, source_key=mask_key, margin=2),
        DeleteItemsD(keys=mask_key),
        ResizeD(keys=image_key, spatial_size=[hparams.input_size] * spatial_dims, mode=interp_mode),
    ]), [image_key]