import monai
import torch
import itk
import numpy as np
import glob
import os


def path_to_id(path):
    """Return the file ID: the basename with all extensions stripped."""
    return os.path.basename(path).split('.')[0]


def split_data(img_path, seg_path, num_seg):
    """Split images and segmentations into train/test sets by matching file IDs.

    Only ``num_seg`` of the training segmentations are kept "available", so the
    remaining training images are treated as unlabelled.
    """
    total_img_paths = sorted(glob.glob(os.path.join(img_path, '*.nii.gz')))
    total_seg_paths = sorted(glob.glob(os.path.join(seg_path, '*.nii.gz')))
    # Randomize the order of the images; note that the train/test split below is
    # taken from the sorted segmentation list, so only which training
    # segmentations remain available is randomized.
    np.random.shuffle(total_img_paths)

    # 80/20 train/test split based on the number of segmentations
    num_train = int(round(len(total_seg_paths) * 0.8))
    num_test = len(total_seg_paths) - num_train
    seg_train = total_seg_paths[:num_train]
    seg_test = total_seg_paths[num_train:]

    train = []
    test = []

    img_ids = list(map(path_to_id, total_img_paths))
    # Work on copies so that popping matched test images does not corrupt the
    # lookup lists above (the original code aliased these lists instead of copying).
    img_ids_remaining = list(img_ids)
    img_paths_remaining = list(total_img_paths)

    # Match every test segmentation with its image by ID and remove that image
    # from the pool so it cannot reappear in the training set.
    for seg_index, seg_id in enumerate(map(path_to_id, seg_test)):
        assert seg_id in img_ids
        data_item = {
            'img': total_img_paths[img_ids.index(seg_id)],
            'seg': seg_test[seg_index]
        }
        img_paths_remaining.pop(img_ids_remaining.index(seg_id))
        img_ids_remaining.pop(img_ids_remaining.index(seg_id))
        test.append(data_item)

    img_train = img_paths_remaining

    # Keep only num_seg of the training segmentations "available"
    np.random.shuffle(seg_train)
    if num_seg < len(seg_train):
        seg_train_available = seg_train[:num_seg]
    else:
        seg_train_available = seg_train

    seg_ids = list(map(path_to_id, seg_train_available))
    for img_index, img_id in enumerate(map(path_to_id, img_train)):
        data_item = {'img': img_train[img_index]}
        if img_id in seg_ids:
            data_item['seg'] = seg_train_available[seg_ids.index(img_id)]
        train.append(data_item)

    num_train = len(img_train)
    return train, test, num_train, num_test
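

# Usage sketch (not part of the original script): the paths and num_seg below are
# placeholder values; adjust them to the actual dataset layout.
#   train_items, test_items, n_train, n_test = split_data('data/images', 'data/labels', num_seg=10)
#   # each item is {'img': ...} and, when a label exists for it, also carries 'seg'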


def load_seg_dataset(train, valid):
    """Build cached MONAI datasets for the segmentation data (image + label)."""
    transform_seg_available = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(keys=['img', 'seg'], image_only=True),
            monai.transforms.AddChannelD(keys=['img', 'seg']),
            monai.transforms.SpacingD(keys=['img', 'seg'], pixdim=(1., 1., 1.),
                                      mode=('trilinear', 'nearest')),
            monai.transforms.ToTensorD(keys=['img', 'seg'])
        ]
    )
    # Silence ITK warnings emitted while reading some NIfTI headers
    itk.ProcessObject.SetGlobalWarningDisplay(False)
    dataset_seg_available_train = monai.data.CacheDataset(
        data=train,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True
    )
    dataset_seg_available_valid = monai.data.CacheDataset(
        data=valid,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True
    )
    return dataset_seg_available_train, dataset_seg_available_valid
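

# Usage sketch (variable names taken from the split_data example above): only items
# that actually have a 'seg' key can go through this transform chain, since it does
# not set allow_missing_keys.
#   ds_train, ds_valid = load_seg_dataset(
#       [d for d in train_items if 'seg' in d],
#       [d for d in test_items if 'seg' in d])
#   loader = monai.data.DataLoader(ds_train, batch_size=2, shuffle=True)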


def load_reg_dataset(train, valid):
    """Build cached MONAI datasets for registration pairs, grouped by which
    segmentations are available ('00', '01', '10', '11')."""
    transform_pair = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(
                keys=['img1', 'seg1', 'img2', 'seg2'], image_only=True, allow_missing_keys=True),
            monai.transforms.ToTensorD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.AddChannelD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.SpacingD(
                keys=['img1', 'seg1', 'img2', 'seg2'], pixdim=(1., 1., 1.),
                mode=('trilinear', 'nearest', 'trilinear', 'nearest'), allow_missing_keys=True),
            # Stack the two images into a single two-channel tensor for the registration network
            monai.transforms.ConcatItemsD(
                keys=['img1', 'img2'], name='img12', dim=0),
            monai.transforms.DeleteItemsD(keys=['img1', 'img2'])
        ]
    )
    dataset_pairs_train_subdivided = {
        seg_availability: monai.data.CacheDataset(
            data=data_list,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True
        )
        for seg_availability, data_list in train.items()
    }
    dataset_pairs_valid_subdivided = {
        seg_availability: monai.data.CacheDataset(
            data=data_list,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True
        )
        for seg_availability, data_list in valid.items()
    }
    return dataset_pairs_train_subdivided, dataset_pairs_valid_subdivided
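

# Usage sketch (hypothetical variable names): the expected inputs are the
# availability-keyed dicts produced by subdivide_list_of_data_pairs (defined below).
#   pairs_train = subdivide_list_of_data_pairs(take_data_pairs(train_items))
#   pairs_valid = subdivide_list_of_data_pairs(take_data_pairs(test_items))
#   reg_train, reg_valid = load_reg_dataset(pairs_train, pairs_valid)
#   # each transformed item then holds a 2-channel 'img12' tensor plus any 'seg1'/'seg2'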


def take_data_pairs(data, symmetric=True):
    """Given a list of dicts that each have an image key and possibly a segmentation key,
    return a list of dicts corresponding to *pairs* of images and, where available, their
    segmentations. Pairs consisting of a repeated image are not included.
    If symmetric is True, then for each pair that is included, its reverse is also included."""
    data_pairs = []
    for i in range(len(data)):
        j_limit = len(data) if symmetric else i
        for j in range(j_limit):
            if j == i:
                continue
            d1 = data[i]
            d2 = data[j]
            pair = {
                'img1': d1['img'],
                'img2': d2['img']
            }
            if 'seg' in d1:
                pair['seg1'] = d1['seg']
            if 'seg' in d2:
                pair['seg2'] = d2['seg']
            data_pairs.append(pair)
    return data_pairs
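

# Small illustration (hypothetical file names), showing how each segmentation follows
# its image into the 'seg1'/'seg2' slot:
#   take_data_pairs([{'img': 'a.nii.gz', 'seg': 'a_seg.nii.gz'}, {'img': 'b.nii.gz'}])
#   -> [{'img1': 'a.nii.gz', 'img2': 'b.nii.gz', 'seg1': 'a_seg.nii.gz'},
#       {'img1': 'b.nii.gz', 'img2': 'a.nii.gz', 'seg2': 'a_seg.nii.gz'}]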


def subdivide_list_of_data_pairs(data_pairs_list):
    """Group image pairs by segmentation availability: '11' both segmented,
    '10' only the first, '01' only the second, '00' neither."""
    out_dict = {'00': [], '01': [], '10': [], '11': []}
    for d in data_pairs_list:
        if 'seg1' in d and 'seg2' in d:
            out_dict['11'].append(d)
        elif 'seg1' in d:
            out_dict['10'].append(d)
        elif 'seg2' in d:
            out_dict['01'].append(d)
        else:
            out_dict['00'].append(d)
    return out_dict
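

if __name__ == '__main__':
    # Minimal end-to-end sketch, assuming images in data/images/*.nii.gz and labels in
    # data/labels/*.nii.gz with matching file IDs; these paths and num_seg=10 are
    # placeholders rather than part of the original script.
    train_items, test_items, n_train, n_test = split_data('data/images', 'data/labels', num_seg=10)
    seg_train_ds, seg_valid_ds = load_seg_dataset(
        [d for d in train_items if 'seg' in d],
        [d for d in test_items if 'seg' in d])
    pairs_train = subdivide_list_of_data_pairs(take_data_pairs(train_items))
    pairs_valid = subdivide_list_of_data_pairs(take_data_pairs(test_items))
    reg_train_ds, reg_valid_ds = load_reg_dataset(pairs_train, pairs_valid)
    print('train images:', n_train, 'test images:', n_test,
          'pairs per availability:', {k: len(v) for k, v in pairs_train.items()})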