katielink's picture
complete the model package
5961189
import numpy as np
import torch
from monai.transforms import InvertibleTransform
from monai.transforms.transform import MapTransform
class ConcatImages(MapTransform, InvertibleTransform):
    """Concatenate several dictionary entries into one new entry.

    The values stored under ``keys_merge`` are joined with ``np.concatenate`` (axis 0)
    and written to ``keys_out``.  The meta dict of the first merged key
    (``<keys_merge[0]>_meta_dict``) is assigned, by reference, to
    ``<keys_out>_meta_dict``.

    Args:
        keys_merge: keys whose values are concatenated, in this order.
        keys_out: key under which the concatenated result is stored.
        allow_missing_keys: forwarded to ``MapTransform``.
            NOTE(review): ``__call__`` does not consult this flag, so a missing key
            (or a missing ``<keys_merge[0]>_meta_dict``) still raises ``KeyError``.
    """

    def __init__(self, keys_merge, keys_out, allow_missing_keys=True):
        # Computed first so an empty ``keys_merge`` keeps failing with IndexError as before.
        self.key_target_meta = keys_merge[0] + "_meta_dict"
        # Initialise the MapTransform base (sets ``self.keys`` / ``self.allow_missing_keys``);
        # the previous implementation skipped this, leaving ``self.keys`` undefined.
        super().__init__(keys_merge, allow_missing_keys)
        self.keys_merge = keys_merge
        self.keys_out = keys_out

    def _concat_row(self, row):
        """Add the concatenated value and its meta dict to one data dictionary, in place."""
        row[self.keys_out] = np.concatenate([row[key] for key in self.keys_merge])
        row[self.keys_out + "_meta_dict"] = row[self.key_target_meta]

    def __call__(self, data):
        """Apply the concatenation to a dict, or to every dict in a list.

        The input is mutated in place and the same object is returned.
        """
        rows = data if isinstance(data, list) else [data]
        for row in rows:
            self._concat_row(row)
        return data

    def inverse(self, data):
        """Return ``data`` unchanged; the concatenated entry is not removed or split."""
        return data
class MergeClassesd(MapTransform):
    """Collapse a channel-first per-class tensor into a single-channel label map.

    For every key in ``self.keys`` the value is treated as a torch tensor whose last
    four dims are ``(C, H, W, D)``; ``C`` is read from dim ``-4`` and any leading dims
    must be singletons.  Channel ``i`` is multiplied by ``i`` and the channels are then
    merged voxel-wise: for finite inputs a voxel takes the value of the first channel
    (in index order) that is non-zero there, or 0 if none is; channel 0 is scaled by 0
    and so never contributes.  The result, of shape ``(1, H, W, D)``, replaces the entry.

    NOTE(review): for a binary one-hot input this yields the class index; for soft
    (probability) inputs the value is ``p * i`` rather than an index — confirm intended.

    Raises:
        KeyError: a key is missing and ``allow_missing_keys`` is False.
        ValueError: the tensor has a non-singleton dim before ``(C, H, W, D)``.
    """

    @staticmethod
    def _merge(img):
        """Return the ``(1, H, W, D)`` merged label map for one ``(..., C, H, W, D)`` tensor."""
        num_classes = img.size(-4)
        if any(dim != 1 for dim in img.shape[:-4]):
            # Previously a batch > 1 was silently merged across samples into garbage.
            raise ValueError(
                f"{MergeClassesd.__name__} expects only singleton dims before (C, H, W, D), "
                f"got shape {tuple(img.shape)}."
            )
        # Drop only the leading singleton dims. ``squeeze()`` also removed size-1
        # spatial dims (or C == 1), breaking the (C, 1, 1, 1) broadcast below.
        img = img.reshape(img.shape[-4:])
        weights = torch.arange(num_classes, device=img.device).view(-1, 1, 1, 1)
        merged = None
        for channel in img * weights:
            if merged is None:
                merged = channel
            else:
                # Where both are non-zero keep the earlier value; elsewhere the sum
                # equals whichever operand is non-zero.
                merged = merged + channel * ~((merged != 0) & (channel != 0))
        return merged.unsqueeze(0)

    def __call__(self, data):
        for key in self.keys:
            if key not in data:
                if not self.allow_missing_keys:
                    raise KeyError(
                        f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
                        " and allow_missing_keys==False."
                    )
                continue
            data[key] = self._merge(data[key])
        return data