|
from natsort import natsorted
|
|
import numpy as np
|
|
from pathlib import Path
|
|
import os
|
|
from os.path import join
|
|
from nnunet.dataset_conversion.utils import generate_dataset_json
|
|
import SimpleITK as sitk
|
|
import gc
|
|
import multiprocessing as mp
|
|
from functools import partial
|
|
|
|
|
|
def preprocess_dataset(ribfrac_load_path, ribseg_load_path, dataset_save_path, pool, max_train_id=500):
    """Convert RibFrac images and RibSeg masks into an nnU-Net style dataset.

    Creates ``imagesTr``, ``labelsTr``, ``imagesTs`` and ``labelsTs`` below
    *dataset_save_path*, then converts every file found in
    ``<ribseg_load_path>/labelsTr`` in parallel with ``pool.map``.

    Args:
        ribfrac_load_path: Root of the RibFrac data; images are read from its
            ``imagesTr`` / ``imagesTs`` sub-folders.
        ribseg_load_path: Root of the RibSeg data; masks are read from its
            ``labelsTr`` sub-folder.
        dataset_save_path: Output root of the converted dataset.
        pool: A ``multiprocessing.Pool`` (anything providing a blocking ``map``).
        max_train_id: Cases whose ID is greater than this go to the test split.
    """
    mask_load_path = join(ribseg_load_path, "labelsTr")

    for subset in ("imagesTr", "labelsTr", "imagesTs", "labelsTs"):
        Path(join(dataset_save_path, subset)).mkdir(parents=True, exist_ok=True)

    mask_filenames = load_filenames(mask_load_path)

    # Hand the output root to the workers explicitly instead of relying on a
    # module global that only exists when this file runs as a script (and is
    # only visible to workers under the "fork" start method).
    worker = partial(
        preprocess_single,
        image_load_path=ribfrac_load_path,
        dataset_save_path=dataset_save_path,
        max_train_id=max_train_id,
    )
    pool.map(worker, mask_filenames)


def preprocess_single(filename, image_load_path, dataset_save_path=None, max_train_id=500):
    """Convert one RibSeg mask and its matching RibFrac image.

    Args:
        filename: Path of the mask file (name expected to start with
            ``RibFrac<ID>-``).
        image_load_path: Root of the RibFrac data containing ``imagesTr`` /
            ``imagesTs``.
        dataset_save_path: Output root. When ``None``, falls back to the
            module-level ``dataset_save_path`` global (previous behaviour).
        max_train_id: Cases whose ID is greater than this go to the test split.

    Raises:
        ValueError: If no output root is given and no global fallback exists.
    """
    name = os.path.basename(filename)

    # Files containing "-cl.nii.gz" are skipped (presumably RibSeg centerline
    # annotations stored next to the masks -- confirm).
    if "-cl.nii.gz" in name:
        return

    if dataset_save_path is None:
        dataset_save_path = globals().get("dataset_save_path")
        if dataset_save_path is None:
            raise ValueError("dataset_save_path must be provided")

    # "RibFrac<ID>-..." -> <ID>: drop everything after the first "-" and the prefix.
    case_id = int(name.split("-")[0][len("RibFrac"):])

    split = "Ts" if case_id > max_train_id else "Tr"
    image_set = "images" + split
    mask_set = "labels" + split

    image = load_image(join(image_load_path, image_set, "RibFrac{}-image.nii.gz".format(case_id)), is_seg=False)
    mask, spacing, _, _ = load_image(filename, return_meta=True, is_seg=True)

    case_name = "RibSeg_" + str(case_id).zfill(4)
    # The mask's spacing is written to both the image and the label file.
    save_image(join(dataset_save_path, image_set, case_name + "_0000.nii.gz"), image, spacing=spacing, is_seg=False)
    save_image(join(dataset_save_path, mask_set, case_name + ".nii.gz"), mask, spacing=spacing, is_seg=True)
|
|
|
|
|
|
def load_filenames(img_dir, extensions=None):
    """Return the naturally sorted paths of the entries in *img_dir*.

    Args:
        img_dir: Directory to list.
        extensions: Optional suffix, or tuple of suffixes (anything accepted by
            ``str.endswith``); when given, only entries ending with one of
            them are kept.

    Returns:
        A list of ``str`` paths, each being *img_dir* (with a trailing "/")
        followed by the entry name, sorted with ``natsort.natsorted`` so that
        e.g. "RibFrac10" comes after "RibFrac2".
    """
    _img_dir = fix_path(img_dir)
    # os.listdir also yields sub-directories; they are not filtered out here.
    img_filenames = [
        _img_dir + file
        for file in os.listdir(_img_dir)
        if extensions is None or file.endswith(extensions)
    ]
    return natsorted(img_filenames)
|
|
|
|
|
|
def fix_path(path):
    """Return *path* as a string that ends with a directory separator.

    Accepts ``str`` or any ``os.PathLike``. An empty path is returned
    unchanged instead of being turned into the filesystem root "/".
    """
    path = os.fspath(path)
    if path and not path.endswith(("/", os.sep)):
        path += "/"
    return path
|
|
|
|
|
|
def load_image(filepath, return_meta=False, is_seg=False):
    """Read *filepath* with SimpleITK and return its voxels as a numpy array.

    Args:
        filepath: Path of the image file to read.
        return_meta: When True, additionally return spacing, affine and header.
        is_seg: When True, round the voxel values and cast them to ``int8``.

    Returns:
        Either the array alone, or ``(array, spacing, affine, header)`` where
        ``spacing`` comes from ``GetSpacing()``, ``affine`` is always ``None``
        and ``header`` maps each metadata key to ``GetMetaData(key)``.
    """
    sitk_image = sitk.ReadImage(filepath)
    array = sitk.GetArrayFromImage(sitk_image)

    if is_seg:
        array = np.rint(array).astype(np.int8)

    if not return_meta:
        return array

    voxel_spacing = sitk_image.GetSpacing()
    header = {}
    for meta_key in sitk_image.GetMetaDataKeys():
        header[meta_key] = sitk_image.GetMetaData(meta_key)
    return array, voxel_spacing, None, header
|
|
|
|
|
|
def save_image(filename, image, spacing=None, affine=None, header=None, is_seg=False, mp_pool=None, free_mem=False):
    """Write a numpy array to *filename* using SimpleITK.

    Args:
        filename: Output path; SimpleITK picks the format from the extension.
        image: Array converted with ``sitk.GetImageFromArray``.
        spacing: Optional voxel spacing applied with ``SetSpacing``.
        affine: Accepted for interface compatibility; currently ignored.
        header: Optional mapping copied onto the image with ``SetMetaData``.
        is_seg: When True, round the array and cast it to ``int8`` first.
        mp_pool: Optional pool; when given, the write is submitted with
            ``apply_async`` instead of running in this process.
        free_mem: When True, drop the local image reference and run
            ``gc.collect()`` after writing/submitting.

    Returns:
        ``None`` for a synchronous write. For an asynchronous write, the
        ``AsyncResult`` of the submitted task; call ``.get()`` on it to wait
        and to re-raise any error from the worker (otherwise it is lost).
    """
    if is_seg:
        image = np.rint(image).astype(np.int8)

    image = sitk.GetImageFromArray(image)

    if header is not None:
        for key, value in header.items():
            image.SetMetaData(key, value)

    if spacing is not None:
        image.SetSpacing(spacing)

    # NOTE(review): `affine` is not applied; origin and direction stay as set
    # by GetImageFromArray.

    if mp_pool is None:
        sitk.WriteImage(image, filename)
        result = None
    else:
        result = mp_pool.apply_async(_save, args=(filename, image, free_mem))

    if free_mem:
        del image
        gc.collect()

    return result
|
|
|
|
|
|
def _save(filename, image, free_mem):
    """Pool-side helper of ``save_image``: write *image* to *filename*.

    When *free_mem* is set, the local reference is released and a garbage
    collection is forced once the write has finished.
    """
    sitk.WriteImage(image, filename)
    if not free_mem:
        return
    del image
    gc.collect()
|
|
|
|
|
|
if __name__ == "__main__":
    ribfrac_load_path = "/home/k539i/Documents/datasets/original/RibFrac/"
    ribseg_load_path = "/home/k539i/Documents/datasets/original/RibSeg/"
    dataset_save_path = "/home/k539i/Documents/datasets/preprocessed/Task156_RibSeg/"

    # NOTE: the train/test split threshold used during preprocessing is the
    # default of 500 in the preprocessing functions; this value is kept for
    # reference and is not passed on.
    max_imagesTr_id = 500

    # The with-block terminates the workers even if preprocessing raises; on
    # success they are closed and joined before leaving it.
    with mp.Pool(processes=20) as pool:
        preprocess_dataset(ribfrac_load_path, ribseg_load_path, dataset_save_path, pool)
        # pool.map inside preprocess_dataset blocks until every case has been
        # written, so this only waits for the worker processes to exit.
        print("Waiting for worker processes to exit...")
        pool.close()
        pool.join()

    print("All tasks finished.")

    generate_dataset_json(join(dataset_save_path, 'dataset.json'), join(dataset_save_path, "imagesTr"), None, ('CT',), {0: 'bg', 1: 'rib'}, "Task156_RibSeg")
|
|
|