|
import copy

import numpy as np
import torch
from monai.transforms import InvertibleTransform
from monai.transforms.transform import MapTransform
|
|
|
|
|
class ConcatImages(MapTransform, InvertibleTransform):
    """Concatenate several entries of a data dictionary into one new entry.

    The values stored under ``keys_merge`` are joined, in order, with
    ``np.concatenate`` along their first axis and written to ``keys_out``.
    The output also receives a ``"<keys_out>_meta_dict"`` entry, which is an
    independent (deep) copy of the first merge key's ``"_meta_dict"`` entry.

    Args:
        keys_merge: keys whose values are concatenated, in this order. The
            first key's ``"<key>_meta_dict"`` entry must be present in the data.
        keys_out: key under which the concatenated array is stored.
        allow_missing_keys: forwarded to ``MapTransform``.
            NOTE(review): the concatenation itself does not consult this flag,
            so a missing merge key still raises ``KeyError``.
    """

    def __init__(self, keys_merge, keys_out, allow_missing_keys=True):
        # Register the merge keys with MapTransform so ``self.keys`` and
        # ``self.allow_missing_keys`` are set as the MapTransform contract expects.
        super().__init__(keys_merge, allow_missing_keys)
        self.keys_merge = keys_merge
        self.keys_out = keys_out
        self.key_target_meta = keys_merge[0] + "_meta_dict"
        self.allow_missing_keys = allow_missing_keys

    def _concat_row(self, row):
        """Add the concatenated image and its meta dict to one data dict, in place."""
        row[self.keys_out] = np.concatenate([row[key] for key in self.keys_merge])
        # Deep copy so that later edits to the output's meta dict (e.g. an updated
        # affine) do not also modify the first merge key's meta dict.
        row[self.keys_out + "_meta_dict"] = copy.deepcopy(row[self.key_target_meta])

    def __call__(self, data):
        """Apply the concatenation.

        Args:
            data: a mapping, or a list of mappings. Each mapping is updated in place.

        Returns:
            ``data`` itself, after its mapping(s) have been updated.

        Raises:
            KeyError: if a merge key or the first key's meta dict is missing.
        """
        rows = data if isinstance(data, list) else [data]
        for row in rows:
            self._concat_row(row)
        return data

    def inverse(self, data):
        """No-op inverse: ``__call__`` records nothing to undo, so ``data`` is returned as-is."""
        return data
|
|
|
|
|
class MergeClassesd(MapTransform):
    """Collapse a multi-channel (one-hot style) label into a single-channel label map.

    For every key in ``self.keys``, the value is treated as a tensor whose
    channel axis is dim ``-4`` (e.g. ``(C, H, W, D)``). Channel ``i`` is scaled
    by ``i``, and the scaled channels are merged in ascending order. Each voxel
    takes the value of the lowest-index channel whose scaled value is non-zero,
    so for binary one-hot input it becomes the class index. Channel 0 is scaled
    by 0 and therefore contributes nothing. The result keeps a singleton channel
    axis at dim ``-4``, e.g. ``(1, H, W, D)``.

    Leading singleton axes beyond rank 4 (e.g. a batch axis of size 1) are
    dropped before merging, so ``(1, C, H, W, D)`` also yields ``(1, H, W, D)``.
    """

    def __call__(self, data):
        """Merge the channels of every configured key in ``data``, in place.

        Raises:
            KeyError: if a key is missing and ``allow_missing_keys`` is False.
        """
        for key in self.keys:
            if key in data:
                data[key] = self._merge_channels(data[key])
            elif not self.allow_missing_keys:
                raise KeyError(
                    f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
                    " and allow_missing_keys==False."
                )
        return data

    @staticmethod
    def _merge_channels(volume):
        """Merge the channel axis (dim -4) of ``volume`` into one label channel."""
        # Drop only *leading* singleton axes. A blanket ``squeeze()`` would also
        # remove singleton spatial axes (e.g. a depth of 1) or a single class
        # channel, after which the per-channel weights broadcast against the
        # wrong axes.
        while volume.dim() > 4 and volume.size(0) == 1:
            volume = volume.squeeze(0)
        num_classes = volume.size(-4)
        weights = torch.arange(num_classes, device=volume.device).view(-1, 1, 1, 1)
        merged = None
        # Unbind along the channel axis explicitly, so a batch axis (if present)
        # is never mistaken for the class axis.
        for channel in (volume * weights).unbind(dim=-4):
            if merged is None:
                merged = channel
            else:
                # Add the new channel only where ``merged`` is still 0. Where both
                # are non-zero the mask is False, so the lower-index class wins.
                merged = merged + channel * ~((merged != 0) & (channel != 0))
        return merged.unsqueeze(-4)
|
|