import glob
import os

import itk
import monai
import numpy as np


def path_to_id(path):
    """Extract the case id from a file path, e.g. '/data/case7.nii.gz' -> 'case7'."""
    return os.path.basename(path).split('.')[0]


def split_data(img_path, seg_path, num_seg):
|
    """Collect image and segmentation paths, pair them by filename id, and
    build a train/test split.

    The test set consists of the last 20% of the sorted segmentations together
    with their matching images; the remaining images form the training set, of
    which at most num_seg keep their segmentation. Returns
    (train, test, num_train, num_test), where train and test are lists of
    dicts holding an 'img' path and, where available, a 'seg' path.
    """
    total_img_paths = sorted(glob.glob(os.path.join(img_path, '*.nii.gz')))
    total_seg_paths = sorted(glob.glob(os.path.join(seg_path, '*.nii.gz')))

    # Shuffle the images so that the returned training list comes out in
    # random order; image/segmentation pairing below is by id, not position.
    np.random.shuffle(total_img_paths)

    # 80/20 train/test split over the segmentations (the scarce resource).
    num_train = int(round(len(total_seg_paths) * 0.8))
    num_test = len(total_seg_paths) - num_train
    seg_train = total_seg_paths[:num_train]
    seg_test = total_seg_paths[num_train:]

    # Pair each test segmentation with its image and remove that image from
    # the training pool so that test images cannot leak into the training set.
    remaining_img_paths = list(total_img_paths)
    remaining_img_ids = list(map(path_to_id, remaining_img_paths))
    test = []
    for seg_index, seg_id in enumerate(map(path_to_id, seg_test)):
        assert seg_id in remaining_img_ids, f'no image for segmentation id {seg_id}'
        img_index = remaining_img_ids.index(seg_id)
        test.append({'img': remaining_img_paths.pop(img_index),
                     'seg': seg_test[seg_index]})
        remaining_img_ids.pop(img_index)
    img_train = remaining_img_paths

    # Keep at most num_seg training segmentations, chosen at random, to mimic
    # a setting where only a few labelled cases are available. (Slicing is
    # safe even when num_seg exceeds the number of segmentations.)
    np.random.shuffle(seg_train)
    seg_train_available = seg_train[:num_seg]
    seg_ids = list(map(path_to_id, seg_train_available))

    train = []
    for img_index, img_id in enumerate(map(path_to_id, img_train)):
        data_item = {'img': img_train[img_index]}
        if img_id in seg_ids:
            data_item['seg'] = seg_train_available[seg_ids.index(img_id)]
        train.append(data_item)

    num_train = len(img_train)
    return train, test, num_train, num_test
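
# Example usage (a sketch; the folder layout and num_seg value below are
# hypothetical). Expects NIfTI volumes named <id>.nii.gz, with matching ids
# between the image and segmentation folders:
#
#   train, test, num_train, num_test = split_data(
#       img_path='data/images', seg_path='data/labels', num_seg=10)
#
# Every test item has both 'img' and 'seg'; train items have 'seg' only for
# the (at most) num_seg images whose segmentation was kept.
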
|
def load_seg_dataset(train, valid): |
|
    # Transforms for items that have both an image and a segmentation: load,
    # add a channel dimension, resample to 1 mm isotropic spacing (nearest
    # neighbour for the label map), and convert to tensors.
    transform_seg_available = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(keys=['img', 'seg'], image_only=True),
            monai.transforms.AddChannelD(keys=['img', 'seg']),
            monai.transforms.SpacingD(keys=['img', 'seg'], pixdim=(1., 1., 1.),
                                      mode=('trilinear', 'nearest')),
            monai.transforms.ToTensorD(keys=['img', 'seg'])
        ]
    )
    # Silence ITK warnings that some NIfTI headers trigger during loading.
    itk.ProcessObject.SetGlobalWarningDisplay(False)
    dataset_seg_available_train = monai.data.CacheDataset(
        data=train,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True
    )
    dataset_seg_available_valid = monai.data.CacheDataset(
        data=valid,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True
    )
    return dataset_seg_available_train, dataset_seg_available_valid
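
# Sketch of one possible next step: wrap the cached datasets in MONAI
# dataloaders for training (batch sizes here are illustrative).
#
#   dataloader_train = monai.data.DataLoader(
#       dataset_seg_available_train, batch_size=2, shuffle=True)
#   dataloader_valid = monai.data.DataLoader(
#       dataset_seg_available_valid, batch_size=2, shuffle=False)
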
|
def load_reg_dataset(train, valid): |
|
    # Transforms for image pairs; 'seg1'/'seg2' may be absent, hence
    # allow_missing_keys. The two images are finally concatenated along the
    # channel dimension into a single 'img12' tensor for registration.
    transform_pair = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(
                keys=['img1', 'seg1', 'img2', 'seg2'], image_only=True,
                allow_missing_keys=True),
            monai.transforms.AddChannelD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.SpacingD(
                keys=['img1', 'seg1', 'img2', 'seg2'], pixdim=(1., 1., 1.),
                mode=('trilinear', 'nearest', 'trilinear', 'nearest'),
                allow_missing_keys=True),
            monai.transforms.ToTensorD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.ConcatItemsD(
                keys=['img1', 'img2'], name='img12', dim=0),
            monai.transforms.DeleteItemsD(keys=['img1', 'img2'])
        ]
    )
    # One cached dataset per segmentation-availability key ('00', '01', '10',
    # '11'); see subdivide_list_of_data_pairs below.
    dataset_pairs_train_subdivided = {
        seg_availability: monai.data.CacheDataset(
            data=data_list,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True
        )
        for seg_availability, data_list in train.items()
    }
    dataset_pairs_valid_subdivided = {
        seg_availability: monai.data.CacheDataset(
            data=data_list,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True
        )
        for seg_availability, data_list in valid.items()
    }
    return dataset_pairs_train_subdivided, dataset_pairs_valid_subdivided
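
# Sketch of consuming the subdivided datasets: one dataloader per
# segmentation-availability key, skipping empty buckets (batch size is
# illustrative).
#
#   dataloaders_train = {
#       k: monai.data.DataLoader(ds, batch_size=1, shuffle=True)
#       for k, ds in dataset_pairs_train_subdivided.items() if len(ds) > 0
#   }
#
# Each batch then carries 'img12' (the two images stacked along the channel
# dimension) plus whichever of 'seg1'/'seg2' exist for that key.
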
|
def take_data_pairs(data, symmetric=True):
    """Given a list of dicts that have keys for an image and maybe a segmentation,
    return a list of dicts corresponding to *pairs* of images and maybe segmentations.
    Pairs consisting of a repeated image are not included.
    If symmetric is True then, for each pair that is included, its reverse is
    also included."""
    data_pairs = []
    for i in range(len(data)):
        # symmetric: all ordered pairs (i, j) with i != j; otherwise each
        # unordered pair appears exactly once, as (i, j) with j < i.
        j_limit = len(data) if symmetric else i
        for j in range(j_limit):
            if j == i:
                continue
            d1 = data[i]
            d2 = data[j]
            pair = {
                'img1': d1['img'],
                'img2': d2['img']
            }
            if 'seg' in d1:
                pair['seg1'] = d1['seg']
            if 'seg' in d2:
                pair['seg2'] = d2['seg']
            data_pairs.append(pair)
    return data_pairs
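
# For example (hypothetical file names), with
#   data = [{'img': 'a.nii.gz', 'seg': 'a_seg.nii.gz'}, {'img': 'b.nii.gz'}]
# and symmetric=True, this returns
#   [{'img1': 'a.nii.gz', 'img2': 'b.nii.gz', 'seg1': 'a_seg.nii.gz'},
#    {'img1': 'b.nii.gz', 'img2': 'a.nii.gz', 'seg2': 'a_seg.nii.gz'}]
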
|
def subdivide_list_of_data_pairs(data_pairs_list):
    """Bucket image pairs by segmentation availability: key '10' means seg1 is
    present and seg2 is not, '01' the reverse, '11' both, '00' neither."""
    out_dict = {'00': [], '01': [], '10': [], '11': []}
    for d in data_pairs_list:
        if 'seg1' in d and 'seg2' in d:
            out_dict['11'].append(d)
        elif 'seg1' in d:
            out_dict['10'].append(d)
        elif 'seg2' in d:
            out_dict['01'].append(d)
        else:
            out_dict['00'].append(d)
    return out_dict
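
# Minimal end-to-end sketch tying the helpers together (hypothetical paths;
# expects <id>.nii.gz files in both folders). Run the module as a script for
# a quick smoke test.
if __name__ == '__main__':
    train, test, num_train, num_test = split_data(
        img_path='data/images', seg_path='data/labels', num_seg=10)
    print(f'{num_train} training images, {num_test} test images')

    # Segmentation datasets use only items that actually have a label.
    seg_train_ds, seg_valid_ds = load_seg_dataset(
        [d for d in train if 'seg' in d], test)

    # Registration datasets use all pairs, bucketed by label availability.
    pairs_train = subdivide_list_of_data_pairs(take_data_pairs(train))
    pairs_valid = subdivide_list_of_data_pairs(take_data_pairs(test))
    reg_train_ds, reg_valid_ds = load_reg_dataset(pairs_train, pairs_valid)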
|
|