import SimpleITK as sitk |
|
import numpy as np |
|
import shutil |
|
from batchgenerators.utilities.file_and_folder_operations import * |
|
from multiprocessing import Pool |
|
from collections import OrderedDict |
|
import copy
# os and pickle are used below; import them explicitly rather than relying on the * import above
import os
import pickle
|
|
|
|
|
def create_nonzero_mask(data):
    """Return a boolean mask that is True wherever at least one channel of `data` is nonzero, with holes filled."""
    from scipy.ndimage import binary_fill_holes
    assert len(data.shape) == 4 or len(data.shape) == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
|
nonzero_mask = np.zeros(data.shape[1:], dtype=bool) |
|
for c in range(data.shape[0]): |
|
this_mask = data[c] != 0 |
|
nonzero_mask = nonzero_mask | this_mask |
|
nonzero_mask = binary_fill_holes(nonzero_mask) |
|
return nonzero_mask |
|
|
|
|
|
def get_bbox_from_mask(mask, outside_value=0):
    """Return the half-open bounding box [[zmin, zmax], [xmin, xmax], [ymin, ymax]] of all voxels != outside_value."""
    mask_voxel_coords = np.where(mask != outside_value)
|
minzidx = int(np.min(mask_voxel_coords[0])) |
|
maxzidx = int(np.max(mask_voxel_coords[0])) + 1 |
|
minxidx = int(np.min(mask_voxel_coords[1])) |
|
maxxidx = int(np.max(mask_voxel_coords[1])) + 1 |
|
minyidx = int(np.min(mask_voxel_coords[2])) |
|
maxyidx = int(np.max(mask_voxel_coords[2])) + 1 |
|
return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]] |
|
|
|
|
|
def crop_to_bbox(image, bbox):
    """Crop a 3D image to the given bounding box; for 2D images only the last two bbox entries are used."""
|
if len(image.shape) == 3: |
|
resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1])) |
|
return image[resizer] |
|
elif len(image.shape) == 2: |
|
resizer = (slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1])) |
|
return image[resizer] |
|
|
|
|
|
def get_case_identifier(case):
    # strip the folder, the ".nii.gz" extension and the trailing modality suffix (e.g. "_0000")
    case_identifier = case[0].split("/")[-1].split(".nii.gz")[0][:-5]
|
return case_identifier |
|
|
|
|
|
def get_case_identifier_from_npz(case):
    # strip the folder and the ".npz" extension
    case_identifier = case.split("/")[-1][:-4]
|
return case_identifier |
|
|
|
|
|
def load_case_from_list_of_files(data_files, seg_file=None):
    """Read all modalities (and optionally the segmentation) of one case and collect its metadata."""
    assert isinstance(data_files, (list, tuple)), "case must be either a list or a tuple"
    properties = OrderedDict()
    data_itk = [sitk.ReadImage(f) for f in data_files]

    # SimpleITK reports size/spacing as (x, y, z); reverse to (z, y, x) to match GetArrayFromImage
    properties["original_size_of_raw_data"] = np.array(data_itk[0].GetSize())[[2, 1, 0]]
    properties["original_spacing"] = np.array(data_itk[0].GetSpacing())[[2, 1, 0]]
|
properties["list_of_data_files"] = data_files |
|
properties["seg_file"] = seg_file |
|
|
|
properties["itk_origin"] = data_itk[0].GetOrigin() |
|
properties["itk_spacing"] = data_itk[0].GetSpacing() |
|
properties["itk_direction"] = data_itk[0].GetDirection() |
|
|
|
data_npy = np.vstack([sitk.GetArrayFromImage(d)[None] for d in data_itk]) |
|
if seg_file is not None: |
|
|
|
seg_itk = sitk.ReadImage(seg_file) |
|
seg_npy = sitk.GetArrayFromImage(seg_itk)[None].astype(np.float32) |
|
else: |
|
seg_npy = None |
|
return data_npy.astype(np.float32), seg_npy, properties |
|
|
|
|
|
def crop_to_nonzero(data, seg=None, nonzero_label=0):
    """
    Crop data (and seg, if given) to the bounding box of the joint nonzero region of all channels.

    :param data: numpy array of shape (C, X, Y, Z)
    :param seg: optional segmentation map; cropped alongside the data
    :param nonzero_label: this will be written into the segmentation map where both seg and the nonzero mask are zero
    :return: cropped data, cropped seg (or None if no seg was given), bounding box
    """
|
|
|
nonzero_mask = create_nonzero_mask(data) |
|
bbox = get_bbox_from_mask(nonzero_mask, 0) |
|
|
|
cropped_data = [] |
|
for c in range(data.shape[0]): |
|
cropped = crop_to_bbox(data[c], bbox) |
|
cropped_data.append(cropped[None]) |
|
data = np.vstack(cropped_data) |
|
|
|
|
|
    if seg is not None:
        # case 1: seg covers the same number of slices (first spatial axis) as the cropped data
        if seg.shape[1] == data.shape[1]:
|
if seg is not None: |
|
cropped_seg = [] |
|
for c in range(seg.shape[0]): |
|
cropped = crop_to_bbox(seg[c], bbox) |
|
cropped_seg.append(cropped[None]) |
|
seg = np.vstack(cropped_seg) |
|
|
|
nonzero_mask = crop_to_bbox(nonzero_mask, bbox)[None] |
|
if seg is not None: |
|
seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label |
|
else: |
|
nonzero_mask = nonzero_mask.astype(int) |
|
nonzero_mask[nonzero_mask == 0] = nonzero_label |
|
nonzero_mask[nonzero_mask > 0] = 0 |
|
seg = nonzero_mask |
|
return data, seg, bbox |
|
|
|
        # case 2: seg has more slices along the first spatial axis than the data;
        # keep all of seg's slices and crop only the remaining axes
        elif seg.shape[1] > data.shape[1]:
            bbox_for_seg = copy.copy(bbox)
            bbox_for_seg[0] = [0, seg.shape[1]]

            # replicate the first slice of the nonzero mask so it covers every slice of seg
            nonzero_mask_seg = np.array([nonzero_mask[0] for _ in range(seg.shape[1])])
|
if seg is not None: |
|
cropped_seg = [] |
|
for c in range(seg.shape[0]): |
|
cropped = crop_to_bbox(seg[c], bbox_for_seg) |
|
cropped_seg.append(cropped[None]) |
|
seg = np.vstack(cropped_seg) |
|
|
|
|
|
|
|
nonzero_mask = crop_to_bbox(nonzero_mask_seg, bbox_for_seg)[None] |
|
|
|
if seg is not None: |
|
seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label |
|
else: |
|
nonzero_mask = nonzero_mask.astype(int) |
|
nonzero_mask[nonzero_mask == 0] = nonzero_label |
|
nonzero_mask[nonzero_mask > 0] = 0 |
|
seg = nonzero_mask |
|
return data, seg, bbox |
|
return data, seg, bbox |
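
# Usage sketch for crop_to_nonzero (illustration only, not part of the original code):
#
#     dummy = np.zeros((1, 8, 8, 8), dtype=np.float32)
#     dummy[0, 2:5, 3:6, 1:4] = 1.0
#     cropped, seg, bbox = crop_to_nonzero(dummy)
#     # cropped.shape == (1, 3, 3, 3); bbox == [[2, 5], [3, 6], [1, 4]]; seg is None here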
|
|
|
|
|
def get_patient_identifiers_from_cropped_files(folder): |
|
return [i.split("/")[-1][:-4] for i in subfiles(folder, join=True, suffix=".npz")] |
|
|
|
|
|
class ImageCropper(object): |
|
def __init__(self, num_threads, output_folder=None): |
|
""" |
|
This one finds a mask of nonzero elements (must be nonzero in all modalities) and crops the image to that mask. |
|
In the case of BRaTS and ISLES data this results in a significant reduction in image size |
|
:param num_threads: |
|
:param output_folder: whete to store the cropped data |
|
:param list_of_files: |
|
""" |
|
self.output_folder = output_folder |
|
self.num_threads = num_threads |
|
|
|
if self.output_folder is not None: |
|
maybe_mkdir_p(self.output_folder) |
|
|
|
@staticmethod |
|
def crop(data, properties, seg=None): |
|
shape_before = data.shape |
|
data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=0) |
|
shape_after = data.shape |
|
print("before crop:", shape_before, "after crop:", shape_after, "spacing:", |
|
np.array(properties["original_spacing"]), "\n") |
|
|
|
properties["crop_bbox"] = bbox |
|
|
|
        if seg is not None:
            # unique labels per slice of the first segmentation channel; slices with an
            # implausible number of distinct values (>= 50) are treated as background only
            classes = [np.unique(segx) for segx in seg[0]]
            for i, c in enumerate(classes):
                classes[i] = c if len(c) < 50 else [0]
            properties["classes"] = classes
            # clamp stray labels below -1 to background
            seg[seg < -1] = 0
|
properties["size_after_cropping"] = data[0].shape |
|
return data, seg, properties |
|
|
|
@staticmethod |
|
def crop_from_list_of_files(data_files, seg_file=None): |
|
data, seg, properties = load_case_from_list_of_files(data_files, seg_file) |
|
return ImageCropper.crop(data, properties, seg) |
|
|
|
    def load_crop_save(self, case, case_identifier, overwrite_existing=False):
        """Crop one case and save it as <case_identifier>.npz plus a <case_identifier>.pkl holding its properties."""
        try:
            print(case_identifier)
|
if overwrite_existing \ |
|
or (not os.path.isfile(os.path.join(self.output_folder, "%s.npz" % case_identifier)) |
|
or not os.path.isfile(os.path.join(self.output_folder, "%s.pkl" % case_identifier))): |
|
data, seg, properties = self.crop_from_list_of_files(case[:-1], case[-1]) |
|
|
|
all_data = np.vstack((data, seg.transpose((1, 0, 2, 3)))) |
|
np.savez_compressed(os.path.join(self.output_folder, "%s.npz" % case_identifier), data=all_data) |
|
with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f: |
|
pickle.dump(properties, f) |
|
except Exception as e: |
|
print("Exception in", case_identifier, ":") |
|
print(e) |
|
raise e |
|
|
|
def get_list_of_cropped_files(self): |
|
return subfiles(self.output_folder, join=True, suffix=".npz") |
|
|
|
def get_patient_identifiers_from_cropped_files(self): |
|
return [i.split("/")[-1][:-4] for i in self.get_list_of_cropped_files()] |
|
|
|
def run_cropping(self, list_of_files, overwrite_existing=False, output_folder=None): |
|
""" |
|
also copied ground truth nifti segmentation into the preprocessed folder so that we can use them for evaluation |
|
on the cluster |
|
:param list_of_files: list of list of files [[PATIENTID_TIMESTEP_0000.nii.gz], [PATIENTID_TIMESTEP_0000.nii.gz]] |
|
:param overwrite_existing: |
|
:param output_folder: |
|
:return: |
|
""" |
|
if output_folder is not None: |
|
self.output_folder = output_folder |
|
|
|
output_folder_gt = os.path.join(self.output_folder, "gt_segmentations") |
|
maybe_mkdir_p(output_folder_gt) |
|
for j, case in enumerate(list_of_files): |
|
if case[-1] is not None: |
|
shutil.copy(case[-1], output_folder_gt) |
|
|
|
list_of_args = [] |
|
for j, case in enumerate(list_of_files): |
|
case_identifier = get_case_identifier(case) |
|
|
|
list_of_args.append((case, case_identifier, overwrite_existing)) |
|
|
|
""" |
|
self.load_crop_save(case, case_identifier) |
|
""" |
|
p = Pool(self.num_threads) |
|
p.starmap(self.load_crop_save, list_of_args) |
|
p.close() |
|
p.join() |
|
|
|
|
|
def load_properties(self, case_identifier): |
|
with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'rb') as f: |
|
properties = pickle.load(f) |
|
return properties |
|
|
|
def save_properties(self, case_identifier, properties): |
|
with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f: |
|
pickle.dump(properties, f) |
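

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline). The folder names and file layout below are
    # assumptions for illustration only; adapt them to your dataset.
    example_cases = [
        ["/data/raw/case_0001_0000.nii.gz", "/data/raw/case_0001_0001.nii.gz", "/data/raw/case_0001_seg.nii.gz"],
    ]
    cropper = ImageCropper(num_threads=4, output_folder="/data/cropped")
    # each inner list holds all modalities of one case followed by its segmentation file (or None)
    cropper.run_cropping(example_cases, overwrite_existing=False)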
|
|