import monai
import torch
import itk
import numpy as np
import glob
import os


def path_to_id(path):
    # Use the file name up to the first '.' (e.g. 'case01.nii.gz' -> 'case01')
    # as the id that links an image to its segmentation.
    return os.path.basename(path).split('.')[0]


def split_data(img_path, seg_path, num_seg):
    """Pair images with segmentations by id, hold out ~20% of the segmented
    cases as a test set, and keep at most `num_seg` segmentations available
    for training.

    Returns (train, test, num_train, num_test); train and test are lists of
    dicts with an 'img' path and, where available, a 'seg' path.
    """
    total_img_paths = sorted(glob.glob(img_path + '/*.nii.gz'))
    total_seg_paths = sorted(glob.glob(seg_path + '/*.nii.gz'))

    # Shuffle the image list so training images come out in random order; the
    # segmentation list stays sorted, so its first 80% defines the training
    # split and the rest the test split.
    np.random.shuffle(total_img_paths)

    num_train = int(round(len(total_seg_paths) * 0.8))
    num_test = len(total_seg_paths) - num_train
    seg_train = total_seg_paths[:num_train]
    seg_test = total_seg_paths[num_train:]

    test = []
    train = []

    img_ids = list(map(path_to_id, total_img_paths))
    # Work on copies so the master lists stay intact while test images are
    # removed one by one.
    img_ids_remaining = list(img_ids)
    img_paths_remaining = list(total_img_paths)
    seg_ids_test = list(map(path_to_id, seg_test))

    # Every test case gets both its image and its segmentation.
    for seg_index, seg_id in enumerate(seg_ids_test):
        assert seg_id in img_ids, f'no image found for segmentation id {seg_id}'
        data_item = {'img': total_img_paths[img_ids.index(seg_id)],
                     'seg': seg_test[seg_index]}
        test.append(data_item)
        img_paths_remaining.pop(img_ids_remaining.index(seg_id))
        img_ids_remaining.pop(img_ids_remaining.index(seg_id))
    img_train = img_paths_remaining

    # Randomly choose which of the training segmentations count as available.
    np.random.shuffle(seg_train)
    if num_seg < len(seg_train):
        seg_train_available = seg_train[:num_seg]
    else:
        seg_train_available = seg_train

    seg_ids = list(map(path_to_id, seg_train_available))
    img_ids = list(map(path_to_id, img_train))
    for img_index, img_id in enumerate(img_ids):
        data_item = {'img': img_train[img_index]}
        if img_id in seg_ids:
            data_item['seg'] = seg_train_available[seg_ids.index(img_id)]
        train.append(data_item)

    num_train = len(img_train)
    return train, test, num_train, num_test


def load_seg_dataset(train, valid):
    # Transforms for items carrying both an image and a segmentation:
    # resample to 1 mm isotropic spacing, with nearest-neighbour interpolation
    # for the label map.
    transform_seg_available = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(keys=['img', 'seg'], image_only=True),
            monai.transforms.AddChannelD(keys=['img', 'seg']),
            monai.transforms.SpacingD(keys=['img', 'seg'], pixdim=(1., 1., 1.),
                                      mode=('trilinear', 'nearest')),
            monai.transforms.ToTensorD(keys=['img', 'seg'])
        ]
    )
    # Silence ITK warnings emitted while reading some NIfTI headers.
    itk.ProcessObject.SetGlobalWarningDisplay(False)
    dataset_seg_available_train = monai.data.CacheDataset(
        data=train,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True
    )
    dataset_seg_available_valid = monai.data.CacheDataset(
        data=valid,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True
    )
    return dataset_seg_available_train, dataset_seg_available_valid


def load_reg_dataset(train, valid):
    # Transforms for image pairs; either segmentation may be missing, hence
    # allow_missing_keys. The two images are concatenated into a single
    # 2-channel item 'img12' and the per-image keys are then dropped.
    transform_pair = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(
                keys=['img1', 'seg1', 'img2', 'seg2'], image_only=True,
                allow_missing_keys=True),
            monai.transforms.ToTensorD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.AddChannelD(
                keys=['img1', 'seg1', 'img2', 'seg2'], allow_missing_keys=True),
            monai.transforms.SpacingD(
                keys=['img1', 'seg1', 'img2', 'seg2'], pixdim=(1., 1., 1.),
                mode=('trilinear', 'nearest', 'trilinear', 'nearest'),
                allow_missing_keys=True),
            monai.transforms.ConcatItemsD(
                keys=['img1', 'img2'], name='img12', dim=0),
            monai.transforms.DeleteItemsD(keys=['img1', 'img2'])
        ]
    )
    # 'train' and 'valid' map a segmentation-availability code ('00', '01',
    # '10', '11') to a list of pair dicts; cache each group separately.
    dataset_pairs_train_subdivided = {
        seg_availability: monai.data.CacheDataset(
            data=data_list,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True
        )
        for seg_availability, data_list in train.items()
    }
    dataset_pairs_valid_subdivided = {
        seg_availability: monai.data.CacheDataset(
            data=data_list,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True
        )
        for seg_availability, data_list in valid.items()
    }
    return dataset_pairs_train_subdivided, dataset_pairs_valid_subdivided

def take_data_pairs(data, symmetric=True):
    """Given a list of dicts that have keys for an image and maybe a
    segmentation, return a list of dicts corresponding to *pairs* of images
    and maybe segmentations. Pairs consisting of a repeated image are not
    included. If symmetric is True, then for each pair that is included its
    reverse is also included."""
    data_pairs = []
    for i in range(len(data)):
        j_limit = len(data) if symmetric else i
        for j in range(j_limit):
            if j == i:
                continue
            d1 = data[i]
            d2 = data[j]
            pair = {
                'img1': d1['img'],
                'img2': d2['img']
            }
            if 'seg' in d1:
                pair['seg1'] = d1['seg']
            if 'seg' in d2:
                pair['seg2'] = d2['seg']
            data_pairs.append(pair)
    return data_pairs


def subdivide_list_of_data_pairs(data_pairs_list):
    # Group pairs by which segmentations are present: the two characters of
    # the key indicate whether 'seg1' and 'seg2' are available.
    out_dict = {'00': [], '01': [], '10': [], '11': []}
    for d in data_pairs_list:
        if 'seg1' in d and 'seg2' in d:
            out_dict['11'].append(d)
        elif 'seg1' in d:
            out_dict['10'].append(d)
        elif 'seg2' in d:
            out_dict['01'].append(d)
        else:
            out_dict['00'].append(d)
    return out_dict
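

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the utilities above).
# The directory names 'data/images' and 'data/labels' and num_seg=10 are
# placeholder assumptions; they presume .nii.gz files whose base names share
# an id so that path_to_id can match an image to its segmentation.
if __name__ == '__main__':
    train, test, num_train, num_test = split_data(
        img_path='data/images', seg_path='data/labels', num_seg=10)
    print(f'{num_train} training images, {num_test} test images')

    # Segmentation datasets: load_seg_dataset expects every item to carry both
    # an 'img' and a 'seg' key, so keep only the labeled training items.
    labeled_train = [d for d in train if 'seg' in d]
    dataset_seg_train, dataset_seg_valid = load_seg_dataset(labeled_train, test)

    # Registration datasets: build all ordered image pairs, group them by which
    # segmentations are present ('00', '01', '10', '11'), then cache each group.
    pairs_train = subdivide_list_of_data_pairs(take_data_pairs(train))
    pairs_valid = subdivide_list_of_data_pairs(take_data_pairs(test))
    dataset_reg_train, dataset_reg_valid = load_reg_dataset(pairs_train, pairs_valid)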