File size: 1,940 Bytes
5961189 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import numpy as np
import torch
from monai.transforms import InvertibleTransform
from monai.transforms.transform import MapTransform
class ConcatImages(MapTransform, InvertibleTransform):
def __init__(self, keys_merge, keys_out, allow_missing_keys=True):
self.keys_merge = keys_merge
self.keys_out = keys_out
self.key_target_meta = keys_merge[0] + "_meta_dict"
self.allow_missing_keys = allow_missing_keys
def __call__(self, data):
if isinstance(data, list):
for data_row in data:
data_row[self.keys_out] = np.concatenate([data_row[key] for key in self.keys_merge])
data_row[self.keys_out + "_meta_dict"] = data_row[self.key_target_meta]
else:
data[self.keys_out] = np.concatenate([data[key] for key in self.keys_merge])
data[self.keys_out + "_meta_dict"] = data[self.key_target_meta]
return data
def inverse(self, data):
return data
class MergeClassesd(MapTransform):
def __call__(self, data):
for key in self.keys:
if key in data:
num_classes = data[key].size(-4)
device = data[key].device
merged = None
for channel in data[key].squeeze() * torch.tensor(list(range(num_classes)), device=device).view(
-1, 1, 1, 1
):
imgvol = channel
if merged is not None:
merged = merged + imgvol * ~((merged != 0) & (imgvol != 0))
else:
merged = imgvol
data[key] = merged.unsqueeze(0)
elif not self.allow_missing_keys:
raise KeyError(
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
" and allow_missing_keys==False."
)
return data
|