"""Convert the RibFrac images and RibSeg masks into an nnU-Net style dataset folder."""

import gc
import multiprocessing as mp
import os
from functools import partial
from os.path import join
from pathlib import Path

import numpy as np
import SimpleITK as sitk
from natsort import natsorted

from nnunet.dataset_conversion.utils import generate_dataset_json


def preprocess_dataset(ribfrac_load_path, ribseg_load_path, dataset_save_path, pool, max_imagesTr_id=500):
    """Create the output folder layout and convert every RibSeg mask in parallel.

    Args:
        ribfrac_load_path: Root folder of the RibFrac images (expected to contain
            ``imagesTr``/``imagesTs`` sub-folders, see ``preprocess_single``).
        ribseg_load_path: Root folder of the RibSeg masks; masks are read from its
            ``labelsTr`` sub-folder.
        dataset_save_path: Root folder of the converted dataset.
        pool: A ``multiprocessing`` pool (anything offering ``map``) used to run
            ``preprocess_single`` once per mask file.
        max_imagesTr_id: Largest case ID that is still written to the training
            split; larger IDs go to the test split. Defaults to the previously
            hard-coded value of 500.
    """
    mask_load_path = join(ribseg_load_path, "labelsTr")

    for subfolder in ("imagesTr", "labelsTr", "imagesTs", "labelsTs"):
        Path(join(dataset_save_path, subfolder)).mkdir(parents=True, exist_ok=True)

    mask_filenames = load_filenames(mask_load_path)

    # The save path is handed to the workers explicitly: relying on a module
    # global only works when worker processes are forked from the script itself.
    worker = partial(
        preprocess_single,
        image_load_path=ribfrac_load_path,
        dataset_save_path=dataset_save_path,
        max_imagesTr_id=max_imagesTr_id,
    )
    pool.map(worker, mask_filenames)


def preprocess_single(filename, image_load_path, dataset_save_path=None, max_imagesTr_id=500):
    """Load one mask and its matching image and save both under the new names.

    Files whose name contains ``-cl.nii.gz`` are skipped. The case ID is parsed
    from the characters after the first seven of the part of the file name that
    precedes the first ``-`` (presumably names look like ``RibFrac<ID>-...``;
    verify against the RibSeg download).

    Args:
        filename: Path of the mask file.
        image_load_path: Root folder of the RibFrac images.
        dataset_save_path: Root folder of the converted dataset. If omitted, the
            module-level ``dataset_save_path`` is used for backward compatibility.
        max_imagesTr_id: Largest case ID that belongs to the training split.

    Raises:
        ValueError: If no save path was passed and no module-level
            ``dataset_save_path`` exists.
    """
    if dataset_save_path is None:
        # Legacy behaviour: the script used to read this from a module global.
        dataset_save_path = globals().get("dataset_save_path")
        if dataset_save_path is None:
            raise ValueError("dataset_save_path must be provided to preprocess_single")

    name = os.path.basename(filename)
    if "-cl.nii.gz" in name:
        return

    case_id = int(name.split("-")[0][7:])
    if case_id > max_imagesTr_id:
        image_set, mask_set = "imagesTs", "labelsTs"
    else:
        image_set, mask_set = "imagesTr", "labelsTr"

    image, _, _, _ = load_image(
        join(image_load_path, image_set, "RibFrac{}-image.nii.gz".format(case_id)),
        return_meta=True,
        is_seg=False,
    )
    mask, spacing, _, _ = load_image(filename, return_meta=True, is_seg=True)

    # Both outputs are written with the spacing read from the mask file.
    case_name = "RibSeg_" + str(case_id).zfill(4)
    save_image(join(dataset_save_path, image_set, case_name + "_0000.nii.gz"), image, spacing=spacing, is_seg=False)
    save_image(join(dataset_save_path, mask_set, case_name + ".nii.gz"), mask, spacing=spacing, is_seg=True)


def load_filenames(img_dir, extensions=None):
    """Return the naturally sorted paths of the entries in ``img_dir``.

    Args:
        img_dir: Directory to list.
        extensions: Optional suffix (or tuple of suffixes) passed to
            ``str.endswith``; ``None`` keeps every entry.

    Returns:
        A list of ``img_dir`` + entry name strings, ordered by ``natsorted``.
    """
    _img_dir = fix_path(img_dir)
    img_filenames = [
        _img_dir + file
        for file in os.listdir(_img_dir)
        if extensions is None or file.endswith(extensions)
    ]
    return natsorted(img_filenames)


def fix_path(path):
    """Return ``path`` with a trailing ``/`` appended if it is missing.

    An empty string is returned unchanged (instead of raising ``IndexError``),
    so it is never silently turned into the filesystem root.
    """
    if path and not path.endswith("/"):
        path += "/"
    return path


def load_image(filepath, return_meta=False, is_seg=False):
    """Read an image with SimpleITK and return it as a NumPy array.

    Args:
        filepath: Path of the image file.
        return_meta: If true, also return spacing, affine and header.
        is_seg: If true, the array is rounded and cast to ``np.int8``.

    Returns:
        The array alone, or ``(array, spacing, affine, header)`` when
        ``return_meta`` is set. ``affine`` is always ``None``.
    """
    image = sitk.ReadImage(filepath)
    image_np = sitk.GetArrayFromImage(image)

    if is_seg:
        image_np = np.rint(image_np)
        image_np = image_np.astype(np.int8)  # In special cases segmentations can contain negative labels, so no np.uint8

    if not return_meta:
        return image_np

    spacing = image.GetSpacing()
    header = {key: image.GetMetaData(key) for key in image.GetMetaDataKeys()}
    affine = None  # How do I get the affine transform with SimpleITK? With NiBabel it is just image.affine
    return image_np, spacing, affine, header


def save_image(filename, image, spacing=None, affine=None, header=None, is_seg=False, mp_pool=None, free_mem=False):
    """Write a NumPy array to ``filename`` with SimpleITK.

    Args:
        filename: Output path.
        image: Array to save.
        spacing: Optional spacing set on the SimpleITK image.
        affine: Accepted for interface compatibility; currently ignored.
        header: Optional mapping of metadata keys to values set on the image.
        is_seg: If true, the array is rounded and cast to ``np.int8`` first.
        mp_pool: Optional pool; if given, the write is submitted with
            ``apply_async`` instead of being executed in the calling process.
        free_mem: If true, drop the local image reference and run the garbage
            collector after writing/submitting.
    """
    if is_seg:
        image = np.rint(image)
        image = image.astype(np.int8)  # In special cases segmentations can contain negative labels, so no np.uint8

    image = sitk.GetImageFromArray(image)

    if header is not None:
        for key, value in header.items():
            image.SetMetaData(key, value)

    if spacing is not None:
        image.SetSpacing(spacing)

    if affine is not None:
        pass  # How do I set the affine transform with SimpleITK? With NiBabel it is just nib.Nifti1Image(img, affine=affine, header=header)

    if mp_pool is None:
        sitk.WriteImage(image, filename)
    else:
        mp_pool.apply_async(_save, args=(filename, image, free_mem,))

    if free_mem:
        del image
        gc.collect()


def _save(filename, image, free_mem):
    """Write ``image`` to ``filename``; target of ``apply_async`` in ``save_image``."""
    sitk.WriteImage(image, filename)
    if free_mem:
        del image
        gc.collect()


if __name__ == "__main__":
    # Note: Due to a bug in SimpleITK 2.1.x a version of SimpleITK < 2.1.0 is required for loading images.
    # Further, we can't copy the images and masks, but have to load them and resample both to the same spacing.

    # Conversion instructions:
    # 1. All images from both training and validation set of the RibFrac dataset need to be downloaded from https://ribfrac.grand-challenge.org/dataset/ into a new folder named RibFrac
    # 2. The RibSeg masks need to be downloaded from https://zenodo.org/record/5336592 into a new folder named RibSeg
    # 3. Follow unpacking instruction for the RibFrac dataset as in Task154_RibFrac
    # 4. Unzip RibSeg_490_nii.zip from the RibSeg dataset and rename the folder labelsTr

    ribfrac_load_path = "/home/k539i/Documents/datasets/original/RibFrac/"
    ribseg_load_path = "/home/k539i/Documents/datasets/original/RibSeg/"
    dataset_save_path = "/home/k539i/Documents/datasets/preprocessed/Task156_RibSeg/"

    max_imagesTr_id = 500

    pool = mp.Pool(processes=20)
    try:
        preprocess_dataset(ribfrac_load_path, ribseg_load_path, dataset_save_path, pool, max_imagesTr_id=max_imagesTr_id)
        print("Still saving images in background...")
    finally:
        # Always release the worker processes, even if the conversion failed.
        pool.close()
        pool.join()
    print("All tasks finished.")

    generate_dataset_json(join(dataset_save_path, 'dataset.json'), join(dataset_save_path, "imagesTr"), None, ('CT',), {0: 'bg', 1: 'rib'}, "Task156_RibSeg")