File size: 1,940 Bytes
5961189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
import torch
from monai.transforms import InvertibleTransform
from monai.transforms.transform import MapTransform


class ConcatImages(MapTransform, InvertibleTransform):
    """Concatenate several entries of a data dict into one new entry.

    The values stored under ``keys_merge`` are joined, in that order, with
    ``np.concatenate`` along axis 0. Axis 0 is presumably the channel axis,
    per MONAI's channel-first convention; confirm this against the upstream
    loaders. The result is written to ``keys_out``.

    When the first merge key has a ``<key>_meta_dict`` entry, that same dict
    object (not a copy) is also exposed under ``keys_out + "_meta_dict"``.

    Args:
        keys_merge: key(s) whose values are concatenated. The first key's
            meta dict is reused for the output.
        keys_out: key that receives the concatenated value.
        allow_missing_keys: if True, a data dict that contains none of
            ``keys_merge`` is passed through unchanged instead of raising.
    """

    def __init__(self, keys_merge, keys_out, allow_missing_keys=True):
        # A bare string would otherwise be indexed character by character below.
        if isinstance(keys_merge, str):
            keys_merge = [keys_merge]
        # Initialise MapTransform so ``self.keys`` is populated and validated
        # like in every other MONAI dictionary transform.
        super().__init__(keys_merge, allow_missing_keys)
        self.keys_merge = keys_merge
        self.keys_out = keys_out
        self.key_target_meta = keys_merge[0] + "_meta_dict"
        self.allow_missing_keys = allow_missing_keys

    def _concat_one(self, row):
        """Concatenate the merge keys of one data dict in place and return it.

        Raises:
            KeyError: if some merge keys are missing. The only exception is
                when ``allow_missing_keys`` is True and *all* merge keys are
                missing; then the row is returned unchanged.
        """
        missing = [key for key in self.keys_merge if key not in row]
        if missing:
            if self.allow_missing_keys and len(missing) == len(self.keys_merge):
                # Nothing to merge for this sample; leave it untouched.
                return row
            raise KeyError(
                f"Keys {missing} of transform `{self.__class__.__name__}` were missing in the data"
                f" (allow_missing_keys={self.allow_missing_keys}; only a sample lacking every"
                " merge key can be skipped)."
            )
        row[self.keys_out] = np.concatenate([row[key] for key in self.keys_merge])
        # The meta dict is optional: loaders may attach metadata to the array itself.
        if self.key_target_meta in row:
            # Shared reference: later in-place edits affect both keys.
            row[self.keys_out + "_meta_dict"] = row[self.key_target_meta]
        return row

    def __call__(self, data):
        """Apply the concatenation to a data dict, or to each dict of a list.

        Dicts are modified in place; the same ``data`` object is returned.
        """
        if isinstance(data, list):
            for data_row in data:
                self._concat_one(data_row)
            return data
        return self._concat_one(data)

    def inverse(self, data):
        """No-op inverse: ``data`` is returned as-is, with ``keys_out`` still present."""
        return data


class MergeClassesd(MapTransform):
    """Collapse a stack of per-class channels into a single label channel.

    For each key in ``self.keys``:

    - Channel ``c`` on the channel axis (dim -4) is multiplied by its index
      ``c``.
    - The channels are then merged so each voxel takes the value of the first
      channel, in index order, whose weighted value is non-zero.
    - For binary per-class masks this yields a label map in which overlapping
      classes resolve to the lowest class index >= 1. Channel 0 always
      contributes 0.

    Leading singleton axes in front of the channel axis (e.g. a batch of 1)
    are dropped first. The merged map is stored back under the same key with
    a size-1 channel axis at dim -4.

    Raises:
        KeyError: if a key is absent and ``allow_missing_keys`` is False.
        ValueError: if the channel axis is empty.
    """

    def __call__(self, data):
        """Merge the class channels of every configured key in ``data`` in place."""
        for key in self.keys:
            if key not in data:
                if not self.allow_missing_keys:
                    raise KeyError(
                        f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
                        " and allow_missing_keys==False."
                    )
                continue
            data[key] = self._merge_channels(data[key])
        return data

    @staticmethod
    def _merge_channels(volume):
        """Return the first-non-zero merge of the index-weighted channels of ``volume``."""
        # Drop only leading singleton axes (e.g. batch of 1). A blanket
        # ``squeeze()`` would also remove a singleton channel or spatial axis and
        # misalign the (C, 1, 1, 1) class weights below.
        while volume.dim() > 4 and volume.size(0) == 1:
            volume = volume.squeeze(0)
        num_classes = volume.size(-4)
        if num_classes == 0:
            raise ValueError("Cannot merge classes of a tensor with an empty channel axis.")
        weights = torch.arange(num_classes, device=volume.device).view(-1, 1, 1, 1)
        merged = None
        for channel in (volume * weights).unbind(-4):
            if merged is None:
                merged = channel
            else:
                # Keep labels already assigned; fill only voxels that are still 0.
                merged = torch.where(merged != 0, merged, channel)
        return merged.unsqueeze(-4)