import monai
import torch
import itk
import numpy as np
import glob
import os


def path_to_id(path):
    # e.g. 'case01.nii.gz' -> 'case01' (everything before the first dot)
    return os.path.basename(path).split('.')[0]


def split_data(img_path, seg_path, num_seg):
    """Split the data into a training set and a test set.

    Each segmentation held out for testing keeps its matching image; all
    remaining images go into training, where at most `num_seg` segmentations
    are kept, so the training set mixes labeled and unlabeled cases.
    """
    total_img_paths = sorted(glob.glob(os.path.join(img_path, '*.nii.gz')))
    total_seg_paths = sorted(glob.glob(os.path.join(seg_path, '*.nii.gz')))

    # Shuffle the segmentations so that the 80/20 train/test split is random.
    np.random.shuffle(total_seg_paths)
    num_train = int(round(len(total_seg_paths) * 0.8))
    num_test = len(total_seg_paths) - num_train
    seg_train = total_seg_paths[:num_train]
    seg_test = total_seg_paths[num_train:]

    train = []
    test = []
    img_ids = list(map(path_to_id, total_img_paths))
    # Work on copies so that the pops below leave the original lists intact.
    img_ids_remaining = list(img_ids)
    img_paths_remaining = list(total_img_paths)
    for seg_index, seg_id in enumerate(map(path_to_id, seg_test)):
        assert seg_id in img_ids, f'no image found for segmentation id {seg_id}'
        test.append({
            'img': total_img_paths[img_ids.index(seg_id)],
            'seg': seg_test[seg_index]
        })
        # Test images must not reappear in the training pool.
        img_paths_remaining.pop(img_ids_remaining.index(seg_id))
        img_ids_remaining.remove(seg_id)

    img_train = img_paths_remaining
    np.random.shuffle(seg_train)
    # Keep at most num_seg training segmentations; slicing clamps, so no
    # explicit bounds check is needed.
    seg_train_available = seg_train[:num_seg]
    seg_ids = list(map(path_to_id, seg_train_available))
    for img_index, img_id in enumerate(map(path_to_id, img_train)):
        data_item = {'img': img_train[img_index]}
        if img_id in seg_ids:
            data_item['seg'] = seg_train_available[seg_ids.index(img_id)]
        train.append(data_item)

    # num_train counts all training images, labeled or not.
    num_train = len(img_train)
    return train, test, num_train, num_test
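
# A minimal usage sketch (hypothetical paths; the two directories are assumed
# to hold matching *.nii.gz files, with only some images segmented):
#
#   train, test, num_train, num_test = split_data('data/imgs', 'data/segs', num_seg=10)
#   # train: [{'img': ...}, {'img': ..., 'seg': ...}, ...]  all images, at most 10 labeled
#   # test:  [{'img': ..., 'seg': ...}, ...]                image/label pairs only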


def load_seg_dataset(train, valid):
    """Build cached datasets for the cases that have segmentations."""
    transform_seg_available = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(keys=['img', 'seg'], image_only=True),
            # AddChannelD assumes MONAI < 1.0; on newer versions use
            # EnsureChannelFirstD(keys=['img', 'seg'], channel_dim='no_channel').
            monai.transforms.AddChannelD(keys=['img', 'seg']),
            monai.transforms.SpacingD(keys=['img', 'seg'], pixdim=(1., 1., 1.), mode=('trilinear', 'nearest')),
            monai.transforms.ToTensorD(keys=['img', 'seg'])
        ]
    )
    # Silence ITK warnings emitted while reading the images.
    itk.ProcessObject.SetGlobalWarningDisplay(False)
    dataset_seg_available_train = monai.data.CacheDataset(
        data=train,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True
    )

    dataset_seg_available_valid = monai.data.CacheDataset(
        data=valid,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True
    )
    return dataset_seg_available_train, dataset_seg_available_valid
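
# Sketch of how the cached datasets might be consumed (assumes MONAI's
# DataLoader; the batch size here is arbitrary):
#
#   loader = monai.data.DataLoader(dataset_seg_available_train, batch_size=2, shuffle=True)
#   for batch in loader:
#       img, seg = batch['img'], batch['seg']  # each (B, 1, H, W, D) after the transforms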


def load_reg_dataset(train, valid):
    """Build cached datasets of image pairs for registration, one per
    segmentation-availability class ('00', '01', '10', '11')."""
    transform_pair = monai.transforms.Compose(
        transforms=[
            # Either side of a pair may lack a segmentation, hence
            # allow_missing_keys on every transform touching 'seg1'/'seg2'.
            monai.transforms.LoadImageD(
                keys=['img1', 'seg1', 'img2', 'seg2'], image_only=True, allow_missing_keys=True),
            monai.transforms.ToTensorD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.AddChannelD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.SpacingD(keys=['img1', 'seg1', 'img2', 'seg2'], pixdim=(1., 1., 1.), mode=(
                'trilinear', 'nearest', 'trilinear', 'nearest'), allow_missing_keys=True),
            # Stack the two images along the channel axis to form the network input.
            monai.transforms.ConcatItemsD(
                keys=['img1', 'img2'], name='img12', dim=0),
            monai.transforms.DeleteItemsD(keys=['img1', 'img2'])
        ]
    )
    dataset_pairs_train_subdivided = {
        seg_availability: monai.data.CacheDataset(
            data=data_list,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True
        )
        for seg_availability, data_list in train.items()
    }

    dataset_pairs_valid_subdivided = {
        seg_availability: monai.data.CacheDataset(
            data=data_list,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True
        )
        for seg_availability, data_list in valid.items()
    }
    return dataset_pairs_train_subdivided, dataset_pairs_valid_subdivided
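
# One loader per availability class is a natural way to consume these, since
# the dict keys within a batch must match (a sketch, not the only option):
#
#   loaders = {
#       avail: monai.data.DataLoader(ds, batch_size=1, shuffle=True)
#       for avail, ds in dataset_pairs_train_subdivided.items()
#   }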


def take_data_pairs(data, symmetric=True):
    """Given a list of dicts that have keys for an image and maybe a segmentation,
    return a list of dicts corresponding to *pairs* of images and maybe segmentations.
    Pairs consisting of a repeated image are not included.
    If symmetric is set to True, then for each pair that is included, its reverse is also included"""
    data_pairs = []
    for i in range(len(data)):
        j_limit = len(data) if symmetric else i
        for j in range(j_limit):
            if j == i:
                continue
            d1 = data[i]
            d2 = data[j]
            pair = {
                'img1': d1['img'],
                'img2': d2['img']
            }
            if 'seg' in d1:
                pair['seg1'] = d1['seg']
            if 'seg' in d2:
                pair['seg2'] = d2['seg']
            data_pairs.append(pair)
    return data_pairs
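
# For example, with data = [{'img': 'a'}, {'img': 'b', 'seg': 'b_seg'}] and
# symmetric=True this returns:
#   [{'img1': 'a', 'img2': 'b', 'seg2': 'b_seg'},
#    {'img1': 'b', 'img2': 'a', 'seg1': 'b_seg'}]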


def subdivide_list_of_data_pairs(data_pairs_list):
    """Group pairs by segmentation availability: in key 'ab', a == '1' means
    img1 has a segmentation and b == '1' means img2 has one."""
    out_dict = {'00': [], '01': [], '10': [], '11': []}
    for d in data_pairs_list:
        if 'seg1' in d and 'seg2' in d:
            out_dict['11'].append(d)
        elif 'seg1' in d:
            out_dict['10'].append(d)
        elif 'seg2' in d:
            out_dict['01'].append(d)
        else:
            out_dict['00'].append(d)
    return out_dict
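
# Continuing the example above, subdivide_list_of_data_pairs(take_data_pairs(data))
# returns:
#   {'00': [], '01': [{'img1': 'a', 'img2': 'b', 'seg2': 'b_seg'}],
#    '10': [{'img1': 'b', 'img2': 'a', 'seg1': 'b_seg'}], '11': []}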