# Chris Xiao
# upload files
# c642393
from natsort import natsorted
import numpy as np
from pathlib import Path
import os
from os.path import join
from nnunet.dataset_conversion.utils import generate_dataset_json
import SimpleITK as sitk
import gc
import multiprocessing as mp
from functools import partial
def preprocess_dataset(ribfrac_load_path, ribseg_load_path, dataset_save_path, pool, max_train_id=500):
    """Convert RibFrac CT images and RibSeg masks into the nnU-Net raw-data layout.

    Creates ``imagesTr``, ``labelsTr``, ``imagesTs`` and ``labelsTs`` under
    ``dataset_save_path``, then converts every mask file found in
    ``<ribseg_load_path>/labelsTr`` (together with its RibFrac image) in
    parallel via ``pool.map``.

    Args:
        ribfrac_load_path: Root of the RibFrac dataset; images are read from its
            ``imagesTr`` / ``imagesTs`` sub-folders.
        ribseg_load_path: Root of the RibSeg dataset containing ``labelsTr``.
        dataset_save_path: Output root of the converted dataset.
        pool: A ``multiprocessing`` pool used to process the cases.
        max_train_id: Cases whose id is greater than this go to the test split.
    """
    mask_load_path = join(ribseg_load_path, "labelsTr")
    for split_dir in ("imagesTr", "labelsTr", "imagesTs", "labelsTs"):
        Path(join(dataset_save_path, split_dir)).mkdir(parents=True, exist_ok=True)
    # Only NIfTI files: any other entry (e.g. ``.DS_Store``) would crash the
    # case-id parsing in preprocess_single.
    mask_filenames = load_filenames(mask_load_path, extensions=(".nii.gz", ".nii"))
    # Pass the output root explicitly. Workers must not depend on a global that
    # exists only in the ``__main__`` block, which spawn-started workers never run.
    worker = partial(
        preprocess_single,
        image_load_path=ribfrac_load_path,
        dataset_save_path=dataset_save_path,
        max_train_id=max_train_id,
    )
    pool.map(worker, mask_filenames)


def preprocess_single(filename, image_load_path, dataset_save_path=None, max_train_id=500):
    """Convert one RibSeg mask and its matching RibFrac image to nnU-Net naming.

    The mask's basename must start with ``RibFrac<id>-`` (e.g.
    ``RibFrac12-<suffix>``). Files whose name contains ``-cl.nii.gz`` are
    skipped. Cases with ``id > max_train_id`` are written to
    ``imagesTs``/``labelsTs``; all others go to ``imagesTr``/``labelsTr``.
    Both the image and the mask are saved with the spacing read from the mask,
    as ``RibSeg_<id:04d>_0000.nii.gz`` and ``RibSeg_<id:04d>.nii.gz``. No
    resampling is performed.

    Args:
        filename: Path of the RibSeg mask file.
        image_load_path: RibFrac root holding ``imagesTr`` / ``imagesTs``.
        dataset_save_path: Output root. ``None`` falls back to a module-level
            ``dataset_save_path`` global, which is what older callers relied on.
        max_train_id: Largest case id that belongs to the training split.

    Raises:
        ValueError: if no output root is given and no such global exists.
    """
    name = os.path.basename(filename)
    if "-cl.nii.gz" in name:
        return
    if dataset_save_path is None:
        # Backward compatibility with callers that set the global in __main__.
        dataset_save_path = globals().get("dataset_save_path")
        if dataset_save_path is None:
            raise ValueError("dataset_save_path must be provided")
    # "RibFrac12-..." -> "RibFrac12" -> 12
    case_id = int(name.split("-")[0][len("RibFrac"):])
    if case_id > max_train_id:
        image_set, mask_set = "imagesTs", "labelsTs"
    else:
        image_set, mask_set = "imagesTr", "labelsTr"
    image = load_image(join(image_load_path, image_set, "RibFrac{}-image.nii.gz".format(case_id)), is_seg=False)
    mask, spacing, _, _ = load_image(filename, return_meta=True, is_seg=True)
    case_name = "RibSeg_" + str(case_id).zfill(4)
    # The image is written with the mask's spacing so that both files agree.
    save_image(join(dataset_save_path, image_set, case_name + "_0000.nii.gz"), image, spacing=spacing, is_seg=False)
    save_image(join(dataset_save_path, mask_set, case_name + ".nii.gz"), mask, spacing=spacing, is_seg=True)
def load_filenames(img_dir, extensions=None):
    """Return the naturally sorted paths of the directory entries in ``img_dir``.

    Args:
        img_dir: Directory to list (``str`` or path-like).
        extensions: Optional suffix filter. Either a single string such as
            ``".nii.gz"`` or an iterable of strings. ``None`` keeps every entry.

    Returns:
        A list of ``fix_path(img_dir) + entry`` strings, ordered by
        ``natsorted``. Entries come from ``os.listdir``, so sub-directories are
        included if they match the filter.
    """
    _img_dir = fix_path(os.fspath(img_dir))
    if extensions is not None and not isinstance(extensions, str):
        # str.endswith only accepts a str or a tuple of str; lists or sets raise TypeError.
        extensions = tuple(extensions)
    img_filenames = [
        _img_dir + entry
        for entry in os.listdir(_img_dir)
        if extensions is None or entry.endswith(extensions)
    ]
    return natsorted(img_filenames)
def fix_path(path):
    """Return ``path`` as a string that is guaranteed to end with ``/``.

    Accepts ``str`` or any ``os.PathLike`` (e.g. ``pathlib.Path``). Indexing a
    ``Path`` with ``path[-1]`` is not possible, so the value is converted with
    ``os.fspath`` first.

    Raises:
        ValueError: if ``path`` is empty. Appending ``/`` to an empty string
            would silently point at the filesystem root.
    """
    path = os.fspath(path)
    if not path:
        raise ValueError("path must be a non-empty directory path")
    if not path.endswith("/"):
        path += "/"
    return path
def load_image(filepath, return_meta=False, is_seg=False):
    """Read an image file with SimpleITK and return its voxel data as a NumPy array.

    Args:
        filepath: Path of the image to read.
        return_meta: When true, also return spacing, affine and header.
        is_seg: When true, round the voxels and cast them to ``np.int8``.

    Returns:
        The array alone, or ``(array, spacing, affine, header)``. Here
        ``spacing`` comes from ``GetSpacing()``, ``affine`` is always ``None``
        (not reconstructed), and ``header`` maps every metadata key to its value.
    """
    sitk_image = sitk.ReadImage(filepath)
    voxels = sitk.GetArrayFromImage(sitk_image)
    if is_seg:
        # Round before casting so near-integer labels keep their class. int8
        # (not uint8) because segmentations can in special cases hold negative labels.
        voxels = np.rint(voxels).astype(np.int8)
    if not return_meta:
        return voxels
    spacing = sitk_image.GetSpacing()
    header = dict((key, sitk_image.GetMetaData(key)) for key in sitk_image.GetMetaDataKeys())
    # No affine is assembled from SimpleITK's origin/direction/spacing, so None fills that slot.
    return voxels, spacing, None, header
def save_image(filename, image, spacing=None, affine=None, header=None, is_seg=False, mp_pool=None, free_mem=False):
    """Write a NumPy array to ``filename`` via SimpleITK.

    Args:
        filename: Output path.
        image: Voxel array to write.
        spacing: Optional spacing passed to ``SetSpacing``.
        affine: Accepted for interface compatibility but currently ignored.
        header: Optional mapping of metadata keys to values, applied with ``SetMetaData``.
        is_seg: When true, round the voxels and cast them to ``np.int8``.
        mp_pool: Optional ``multiprocessing`` pool. If given, the write is
            dispatched with ``apply_async`` and this call returns immediately.
        free_mem: Drop the local image reference and run ``gc.collect()``
            after writing or dispatching.

    Returns:
        ``None`` for a synchronous write. For an asynchronous write, returns the
        ``AsyncResult`` of the task; call ``.get()`` on it to surface write
        errors, which are otherwise lost inside the pool.
    """
    if is_seg:
        # int8 (not uint8): segmentations can in special cases hold negative labels.
        image = np.rint(image).astype(np.int8)
    image = sitk.GetImageFromArray(image)
    if header is not None:
        for key, value in header.items():
            image.SetMetaData(key, value)
    if spacing is not None:
        image.SetSpacing(spacing)
    if affine is not None:
        # NOTE(review): affine is not applied. SimpleITK would need origin and
        # direction set explicitly; this remains unimplemented.
        pass
    async_result = None
    if mp_pool is None:
        sitk.WriteImage(image, filename)
    else:
        async_result = mp_pool.apply_async(_save, args=(filename, image, free_mem))
    if free_mem:
        del image
        gc.collect()
    return async_result
def _save(filename, image, free_mem):
    """Pool-side helper for ``save_image``: write a SimpleITK image to ``filename``.

    If ``free_mem`` is true, the local reference is dropped and a
    garbage-collection pass is forced after the write.
    """
    sitk.WriteImage(image, filename)
    if not free_mem:
        return
    del image
    gc.collect()
if __name__ == "__main__":
    # Entry point: convert the RibFrac images and RibSeg masks into the nnU-Net
    # raw-data layout of Task156_RibSeg, then write its dataset.json.
    #
    # Note: Due to a bug in SimpleITK 2.1.x, a SimpleITK version < 2.1.0 is required for loading the images.
    # Further, the images and masks cannot simply be copied; they have to be loaded and re-saved so that both share the same spacing.
    # NOTE(review): no actual resampling happens in this script. Each image is written with the spacing read from its mask.
    # Conversion instructions:
    # 1. Download all images from both the training and the validation set of the RibFrac dataset from https://ribfrac.grand-challenge.org/dataset/ into a new folder named RibFrac
    # 2. Download the RibSeg masks from https://zenodo.org/record/5336592 into a new folder named RibSeg
    # 3. Follow the unpacking instructions for the RibFrac dataset as in Task154_RibFrac
    # 4. Unzip RibSeg_490_nii.zip from the RibSeg dataset and rename the extracted folder to labelsTr
    ribfrac_load_path = "/home/k539i/Documents/datasets/original/RibFrac/"
    ribseg_load_path = "/home/k539i/Documents/datasets/original/RibSeg/"
    # NOTE(review): preprocess_single may read this name as a module-level
    # global. Worker processes only see it if they inherit this module state
    # (fork start method), so keep the name unchanged.
    dataset_save_path = "/home/k539i/Documents/datasets/preprocessed/Task156_RibSeg/"
    # NOTE(review): this value is not passed to preprocess_dataset. The
    # train/test cut-off (id > 500 -> test split) is set inside the
    # preprocessing functions, so keep the two in sync by hand.
    max_imagesTr_id = 500
    pool = mp.Pool(processes=20)  # 20 worker processes
    preprocess_dataset(ribfrac_load_path, ribseg_load_path, dataset_save_path, pool)
    # NOTE(review): pool.map inside preprocess_dataset blocks until every case
    # has been processed, and the workers save synchronously. Little or no
    # saving should still be pending here, despite the message.
    print("Still saving images in background...")
    pool.close()
    pool.join()
    print("All tasks finished.")
    # Modalities: a single 'CT' channel. Labels: 0 = background, 1 = rib.
    # NOTE(review): the third argument is None. Presumably this is the
    # test-images directory, so imagesTs cases are not listed; confirm this is intended.
    generate_dataset_json(join(dataset_save_path, 'dataset.json'), join(dataset_save_path, "imagesTr"), None, ('CT',), {0: 'bg', 1: 'rib'}, "Task156_RibSeg")