# Uploaded by Chris Xiao (commit 2ca2f68, 5.75 kB)
import monai
import torch
import itk
import numpy as np
import glob
import os
def path_to_id(path):
    """Return the case ID for *path*: the basename up to the first dot.

    Handles double extensions such as ``case01.nii.gz`` -> ``'case01'``.
    """
    filename = os.path.basename(path)
    return filename.partition('.')[0]
def split_data(img_path, seg_path, num_seg):
    """Split image/segmentation files into training and test sets.

    Segmentations are split 80/20 into train/test. Every image whose
    segmentation landed in the test split becomes a (img, seg) test pair;
    all remaining images go to training, with at most ``num_seg`` of them
    carrying a segmentation.

    Args:
        img_path: directory containing ``*.nii.gz`` image files.
        seg_path: directory containing ``*.nii.gz`` segmentation files,
            matched to images by ID (basename up to the first dot).
        num_seg: maximum number of training segmentations to keep.

    Returns:
        ``(train, test, num_train, num_test)`` where ``train``/``test`` are
        lists of dicts with key ``'img'`` and, when available, ``'seg'``;
        ``num_train`` is ``len(train)`` and ``num_test`` is ``len(test)``.

    Raises:
        ValueError: if a test segmentation has no matching image.

    Note: uses ``np.random.shuffle``; seed NumPy for reproducible splits.
    """
    def _id(p):
        # Basename up to the first dot, e.g. 'case01.nii.gz' -> 'case01'.
        return os.path.basename(p).split('.')[0]

    total_img_paths = sorted(glob.glob(os.path.join(img_path, '*.nii.gz')))
    total_seg_paths = sorted(glob.glob(os.path.join(seg_path, '*.nii.gz')))

    # Randomize image order; train items inherit this order.
    np.random.shuffle(total_img_paths)

    num_train_seg = int(round(len(total_seg_paths) * 0.8))
    num_test = len(total_seg_paths) - num_train_seg
    seg_train = total_seg_paths[:num_train_seg]
    seg_test = total_seg_paths[num_train_seg:]

    # id -> image path for O(1) matching. Replaces the original's repeated
    # list.index() calls and its accidental list aliasing (img_ids1 was an
    # alias of img_ids, not a copy).
    img_by_id = {_id(p): p for p in total_img_paths}

    test = []
    for seg in seg_test:
        seg_id = _id(seg)
        if seg_id not in img_by_id:
            raise ValueError('no image found for test segmentation id %r' % seg_id)
        # pop() removes the image so it cannot leak into the training pool.
        test.append({'img': img_by_id.pop(seg_id), 'seg': seg})

    # Remaining images (shuffled order preserved) form the training pool.
    test_ids = {_id(t['seg']) for t in test}
    img_train = [p for p in total_img_paths if _id(p) not in test_ids]

    # Keep a random subset of at most num_seg training segmentations.
    np.random.shuffle(seg_train)
    seg_train_available = seg_train[:num_seg] if num_seg < len(seg_train) else seg_train
    seg_by_id = {_id(p): p for p in seg_train_available}

    train = []
    for img in img_train:
        item = {'img': img}
        seg = seg_by_id.get(_id(img))
        if seg is not None:
            item['seg'] = seg
        train.append(item)

    num_train = len(img_train)
    return train, test, num_train, num_test
def load_seg_dataset(train, valid):
    """Build cached MONAI datasets for the segmentation-available splits.

    Each sample dict ('img', 'seg') is loaded from disk, given a channel
    dimension, resampled to 1 mm isotropic spacing (trilinear for the image,
    nearest for the segmentation), and converted to tensors.
    """
    transform_seg_available = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(keys=['img', 'seg'], image_only=True),
            monai.transforms.AddChannelD(keys=['img', 'seg']),
            monai.transforms.SpacingD(keys=['img', 'seg'], pixdim=(1., 1., 1.), mode=('trilinear', 'nearest')),
            monai.transforms.ToTensorD(keys=['img', 'seg'])
        ]
    )
    # Turn off ITK's global warning display.
    itk.ProcessObject.SetGlobalWarningDisplay(False)

    def _cached(data):
        # Cache up to 16 preprocessed samples, keyed by data-content hash.
        return monai.data.CacheDataset(
            data=data,
            transform=transform_seg_available,
            cache_num=16,
            hash_as_key=True
        )

    return _cached(train), _cached(valid)
def load_reg_dataset(train, valid):
    """Build cached MONAI datasets for the registration image pairs.

    *train* and *valid* are dicts mapping a seg-availability code
    ('00'/'01'/'10'/'11') to a list of pair dicts with keys 'img1'/'img2'
    and optionally 'seg1'/'seg2'. Each pair is loaded, resampled to 1 mm
    isotropic spacing, and the two images are concatenated channel-wise
    into a single 'img12' tensor (the separate 'img1'/'img2' entries are
    then dropped).
    """
    transform_pair = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(
                keys=['img1', 'seg1', 'img2', 'seg2'], image_only=True, allow_missing_keys=True),
            monai.transforms.ToTensorD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.AddChannelD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.SpacingD(keys=['img1', 'seg1', 'img2', 'seg2'], pixdim=(1., 1., 1.), mode=(
                'trilinear', 'nearest', 'trilinear', 'nearest'), allow_missing_keys=True),
            monai.transforms.ConcatItemsD(
                keys=['img1', 'img2'], name='img12', dim=0),
            monai.transforms.DeleteItemsD(keys=['img1', 'img2'])
        ]
    )

    def _cached_subdivided(subdivided):
        # One CacheDataset per availability bucket, 32 cached samples each.
        return {
            availability: monai.data.CacheDataset(
                data=pair_list,
                transform=transform_pair,
                cache_num=32,
                hash_as_key=True
            )
            for availability, pair_list in subdivided.items()
        }

    return _cached_subdivided(train), _cached_subdivided(valid)
def take_data_pairs(data, symmetric=True):
    """Given a list of dicts that have keys for an image and maybe a segmentation,
    return a list of dicts corresponding to *pairs* of images and maybe segmentations.
    Pairs consisting of a repeated image are not included.
    If symmetric is set to True, then for each pair that is included, its reverse is also included"""
    data_pairs = []
    count = len(data)
    for first_idx, first in enumerate(data):
        # Symmetric: pair with every other item; otherwise only earlier items.
        second_limit = count if symmetric else first_idx
        for second_idx in range(second_limit):
            if second_idx == first_idx:
                continue
            second = data[second_idx]
            entry = {
                'img1': first['img'],
                'img2': second['img']
            }
            if 'seg' in first:
                entry['seg1'] = first['seg']
            if 'seg' in second:
                entry['seg2'] = second['seg']
            data_pairs.append(entry)
    return data_pairs
def subdivide_list_of_data_pairs(data_pairs_list):
    """Bucket pair dicts by segmentation availability.

    Returns a dict keyed '00'/'01'/'10'/'11', where the first digit is 1
    when 'seg1' is present and the second digit is 1 when 'seg2' is present.
    """
    buckets = {'00': [], '01': [], '10': [], '11': []}
    for pair in data_pairs_list:
        key = ('1' if 'seg1' in pair else '0') + ('1' if 'seg2' in pair else '0')
        buckets[key].append(pair)
    return buckets