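# NOTE: the following header was left over from the web file viewer this file was copied from;
# it is not part of the module and is kept only as a comment.
# uploader: ho11laqe | commit: "init" (ecf08bc) | scan: no virus | size: 10.4 kB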
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import SimpleITK as sitk
import numpy as np
import shutil
from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool
from collections import OrderedDict
import copy
def create_nonzero_mask(data):
    """
    Build a boolean mask of all positions where at least one channel of data is nonzero.

    The per-channel ``!= 0`` maps are OR-ed together and enclosed holes in the result are
    filled with scipy's binary_fill_holes.

    :param data: array of shape (C, X, Y, Z) or (C, X, Y); the first axis is iterated as channels
    :return: boolean array of shape data.shape[1:]
    """
    from scipy.ndimage import binary_fill_holes

    n_dims = len(data.shape)
    assert n_dims == 4 or n_dims == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)"

    # start with "nothing is foreground" and accumulate the union over all channels
    foreground = np.zeros(data.shape[1:], dtype=bool)
    for channel_idx in range(data.shape[0]):
        foreground = np.logical_or(foreground, data[channel_idx] != 0)

    return binary_fill_holes(foreground)
def get_bbox_from_mask(mask, outside_value=0):
    """
    Compute the bounding box of all mask entries that differ from outside_value.

    Only the first three axes of mask are considered.

    :param mask: array with (at least) three axes
    :param outside_value: value that marks positions outside the region of interest
    :return: list of three [min, max_exclusive] pairs of python ints, one per axis, so that each
        pair can be used directly as a slice
    """
    inside_coords = np.where(mask != outside_value)

    bbox = []
    for axis in range(3):
        coords_on_axis = inside_coords[axis]
        lower = int(np.min(coords_on_axis))
        upper = int(np.max(coords_on_axis)) + 1  # +1 makes the upper bound exclusive
        bbox.append([lower, upper])
    return bbox
def crop_to_bbox(image, bbox):
    """
    Crop image to the region described by bbox.

    :param image: 3D or 2D array
    :param bbox: three [min, max_exclusive] pairs as returned by get_bbox_from_mask. A 3D image is
        cropped with all three pairs; a 2D image is cropped with the second and third pair only
        (the first pair is ignored).
    :return: the cropped region of image (basic slicing, so no data is copied)
    :raises ValueError: if image is neither 2D nor 3D (previously None was returned silently, which
        only surfaced later as an unrelated error in the caller)
    """
    n_dims = len(image.shape)
    if n_dims == 3:
        bbox_axes = (0, 1, 2)
    elif n_dims == 2:
        bbox_axes = (1, 2)
    else:
        raise ValueError("image must be 2D or 3D, got shape %s" % (tuple(image.shape),))

    resizer = tuple(slice(bbox[axis][0], bbox[axis][1]) for axis in bbox_axes)
    return image[resizer]
def get_case_identifier(case):
    """
    Derive the case identifier from the first file of a case.

    Takes the file name of case[0] (everything after the last "/"), cuts it at the first
    ".nii.gz" and drops the last five characters of what remains (e.g. a "_0000" suffix).

    :param case: sequence of file paths; only the first entry is used
    :return: the case identifier string
    """
    first_file_name = case[0].rsplit("/", 1)[-1]
    name_without_ending = first_file_name.split(".nii.gz")[0]
    return name_without_ending[:-5]
def get_case_identifier_from_npz(case):
    """
    Derive the case identifier from the path of a cropped file.

    :param case: path string; the file name is everything after the last "/"
    :return: the file name without its last four characters (e.g. ".npz")
    """
    file_name = case.rsplit("/", 1)[-1]
    return file_name[:-4]
def load_case_from_list_of_files(data_files, seg_file=None):
    """
    Read all image files of one case (and optionally its segmentation) with SimpleITK.

    Geometry metadata (size, spacing, origin, direction) is taken from the first data file only.

    :param data_files: list or tuple of image file paths, one per channel
    :param seg_file: path of the segmentation file, or None if there is none
    :return: tuple (data, seg, properties) where data is a float32 array with one leading entry
        per data file, seg is a float32 array with a leading axis of size 1 (or None if seg_file
        is None) and properties is an OrderedDict with the metadata listed below
    """
    assert isinstance(data_files, (list, tuple)), "case must be either a list or a tuple"

    images = [sitk.ReadImage(file_name) for file_name in data_files]
    reference_image = images[0]

    # Size and spacing are stored in reversed component order.
    # NOTE(review): presumably this matches the axis order of the arrays returned by
    # sitk.GetArrayFromImage - confirm against the SimpleITK documentation.
    reversed_order = [2, 1, 0]

    properties = OrderedDict()
    properties["original_size_of_raw_data"] = np.array(reference_image.GetSize())[reversed_order]
    properties["original_spacing"] = np.array(reference_image.GetSpacing())[reversed_order]
    properties["list_of_data_files"] = data_files
    properties["seg_file"] = seg_file
    # the itk_* entries keep SimpleITK's own (unreversed) values
    properties["itk_origin"] = reference_image.GetOrigin()
    properties["itk_spacing"] = reference_image.GetSpacing()
    properties["itk_direction"] = reference_image.GetDirection()

    # give every image a leading axis of size 1 and stack the images along it
    channel_arrays = []
    for image in images:
        channel_arrays.append(sitk.GetArrayFromImage(image)[None])
    data_npy = np.vstack(channel_arrays)

    seg_npy = None
    if seg_file is not None:
        seg_image = sitk.ReadImage(seg_file)
        seg_npy = sitk.GetArrayFromImage(seg_image)[None].astype(np.float32)

    return data_npy.astype(np.float32), seg_npy, properties
def _crop_each_channel(volume, bbox):
    """Crop every entry along the first axis of volume to bbox and stack the results again."""
    return np.vstack([crop_to_bbox(channel, bbox)[None] for channel in volume])


def crop_to_nonzero(data, seg=None, nonzero_label=0):
    """
    Crop data (and, if given, seg) to the bounding box of the nonzero region of data.

    The nonzero region is computed with create_nonzero_mask (union over all channels, holes
    filled); its bounding box is applied to every channel of data.

    How seg is handled depends on the size of its second axis compared with the size of the
    second axis of data *before* cropping:

    * equal: seg lies on the same grid as data and is cropped with the same bounding box.
    * larger: the second axis of seg is kept completely and only its last two axes are cropped.
      The first slice of the nonzero mask is used for every entry along that axis.
      NOTE(review): presumably the second axis of seg then holds several label maps that belong
      to a single-slice image - confirm with the caller.
    * smaller, or seg is None: seg is returned unchanged (None stays None).

    In the first two cases every seg entry that is 0 and lies outside the nonzero mask is set to
    nonzero_label. The input arrays are not modified; the cropped arrays are new arrays.

    :param data: array of shape (C, X, Y, Z)
    :param seg: array whose first axis is iterated like channels, or None
    :param nonzero_label: this will be written into the segmentation map where seg is 0 and the
        nonzero mask is False
    :return: tuple (cropped data, seg, bbox) with bbox as returned by get_bbox_from_mask
    """
    # Remember the size before cropping. Comparing seg against the already cropped data would
    # misclassify a seg on the same grid as data whenever the crop shortens the first spatial
    # axis, and that axis of seg would then stay uncropped.
    depth_before_crop = data.shape[1]

    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask, 0)
    data = _crop_each_channel(data, bbox)

    if seg is None:
        return data, seg, bbox

    seg_depth = seg.shape[1]
    if seg_depth == depth_before_crop:
        seg = _crop_each_channel(seg, bbox)
        # leading axis of size 1 lets the mask broadcast over the first axis of seg
        outside_mask = crop_to_bbox(nonzero_mask, bbox)[None] == 0
    elif seg_depth > depth_before_crop:
        # keep the whole second axis of seg, crop only the last two axes
        bbox_for_seg = [[0, seg_depth], bbox[1], bbox[2]]
        seg = _crop_each_channel(seg, bbox_for_seg)
        # the first mask slice applies to every entry of the first two axes of seg; a 2D input
        # is cropped by crop_to_bbox with the second and third bbox pair only
        outside_mask = crop_to_bbox(nonzero_mask[0], bbox)[None, None] == 0
    else:
        return data, seg, bbox

    # Mark background that lies outside the nonzero region so it can be told apart from
    # background inside it (this is a no-op for the default nonzero_label=0).
    seg[(seg == 0) & outside_mask] = nonzero_label
    return data, seg, bbox
def get_patient_identifiers_from_cropped_files(folder):
    """
    List the identifiers of all ".npz" files found in folder.

    :param folder: folder that is searched with subfiles(..., suffix=".npz")
    :return: list with one entry per file: the file name (everything after the last "/") without
        its last four characters
    """
    identifiers = []
    for npz_file in subfiles(folder, join=True, suffix=".npz"):
        file_name = npz_file.split("/")[-1]
        identifiers.append(file_name[:-4])
    return identifiers
class ImageCropper(object):
    """
    Crops cases to the bounding box of their nonzero region and stores the results in a folder.

    For every case a "<case_identifier>.npz" file (cropped data and segmentation) and a
    "<case_identifier>.pkl" file (pickled properties) are written to output_folder.
    """

    def __init__(self, num_threads, output_folder=None):
        """
        This one finds a mask of nonzero elements (nonzero in at least one modality, see
        create_nonzero_mask) and crops the image to the bounding box of that mask.
        In the case of BRaTS and ISLES data this results in a significant reduction in image size

        :param num_threads: number of worker processes used by run_cropping
        :param output_folder: where to store the cropped data; created if it is not None
        """
        self.output_folder = output_folder
        self.num_threads = num_threads

        if self.output_folder is not None:
            maybe_mkdir_p(self.output_folder)

    @staticmethod
    def crop(data, properties, seg=None):
        """
        Crop one case with crop_to_nonzero and record the result in properties.

        :param data: image array, passed on to crop_to_nonzero
        :param properties: dict of the case; must contain "original_spacing". The keys "crop_bbox",
            "size_after_cropping" and (only if a segmentation exists) "classes" are added in place.
        :param seg: segmentation array or None
        :return: tuple (cropped data, seg, properties)
        """
        shape_before = data.shape
        data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=0)
        shape_after = data.shape
        print("before crop:", shape_before, "after crop:", shape_after, "spacing:",
              np.array(properties["original_spacing"]), "\n")

        properties["crop_bbox"] = bbox
        # TODO can only work with <50 segmentation classes in one label
        if seg is not None:
            # one array of unique values per entry along the first axis of seg[0]
            classes = [np.unique(segx) for segx in seg[0]]
            for i, c in enumerate(classes):
                # entries with 50 or more distinct values are replaced by [0]
                classes[i] = c if len(c) < 50 else [0]
            properties["classes"] = classes
            seg[seg < -1] = 0
        properties["size_after_cropping"] = data[0].shape
        return data, seg, properties

    @staticmethod
    def crop_from_list_of_files(data_files, seg_file=None):
        """
        Load a case from its files and crop it.

        :param data_files: list or tuple of image file paths
        :param seg_file: path of the segmentation file or None
        :return: tuple (cropped data, seg, properties), see crop
        """
        data, seg, properties = load_case_from_list_of_files(data_files, seg_file)
        return ImageCropper.crop(data, properties, seg)

    def load_crop_save(self, case, case_identifier, overwrite_existing=False):
        """
        Load, crop and save a single case; this is the unit of work of run_cropping.

        Nothing is done if both output files already exist and overwrite_existing is False.

        :param case: list of files of the case; the last entry is the segmentation file, all
            entries before it are data files
        :param case_identifier: name used for the ".npz" and ".pkl" output files
        :param overwrite_existing: redo the case even if both output files exist
        :raises Exception: any exception is printed together with the case identifier and re-raised
        """
        try:
            print(case_identifier)
            npz_file = os.path.join(self.output_folder, "%s.npz" % case_identifier)
            pkl_file = os.path.join(self.output_folder, "%s.pkl" % case_identifier)

            if overwrite_existing or not (os.path.isfile(npz_file) and os.path.isfile(pkl_file)):
                data, seg, properties = self.crop_from_list_of_files(case[:-1], case[-1])
                # The first two axes of seg are swapped so that the entries of its second axis
                # are appended behind the data channels.
                # NOTE(review): this requires a segmentation; a case without one (case[-1] is
                # None) yields seg=None and fails here - confirm that this is intended.
                all_data = np.vstack((data, seg.transpose((1, 0, 2, 3))))
                np.savez_compressed(npz_file, data=all_data)
                with open(pkl_file, 'wb') as f:
                    pickle.dump(properties, f)
        except Exception as e:
            print("Exception in", case_identifier, ":")
            print(e)
            raise

    def get_list_of_cropped_files(self):
        """
        :return: full paths of all ".npz" files in the output folder
        """
        return subfiles(self.output_folder, join=True, suffix=".npz")

    def get_patient_identifiers_from_cropped_files(self):
        """
        :return: file names of all cropped ".npz" files without their last four characters
        """
        return [i.split("/")[-1][:-4] for i in self.get_list_of_cropped_files()]

    def run_cropping(self, list_of_files, overwrite_existing=False, output_folder=None):
        """
        Crop all cases in parallel with num_threads worker processes.

        also copied ground truth nifti segmentation into the preprocessed folder so that we can use them for evaluation
        on the cluster

        :param list_of_files: list of list of files [[PATIENTID_TIMESTEP_0000.nii.gz], [PATIENTID_TIMESTEP_0000.nii.gz]];
            the last entry of every inner list is the segmentation file (or None)
        :param overwrite_existing: passed on to load_crop_save
        :param output_folder: if not None, replaces the output folder of this instance
        :return:
        """
        if output_folder is not None:
            self.output_folder = output_folder

        output_folder_gt = os.path.join(self.output_folder, "gt_segmentations")
        maybe_mkdir_p(output_folder_gt)
        for case in list_of_files:
            if case[-1] is not None:
                shutil.copy(case[-1], output_folder_gt)

        # one argument tuple per case, unpacked by starmap into a load_crop_save call
        list_of_args = [(case, get_case_identifier(case), overwrite_existing) for case in list_of_files]

        p = Pool(self.num_threads)
        try:
            p.starmap(self.load_crop_save, list_of_args)
        finally:
            # shut the workers down even if a case failed, otherwise the pool is leaked
            p.close()
            p.join()

    def load_properties(self, case_identifier):
        """
        Load the pickled properties of a case from the output folder.

        :param case_identifier: name of the case ("<case_identifier>.pkl" is read)
        :return: the unpickled properties
        """
        # NOTE: pickle.load executes arbitrary code from the file; only use on trusted folders
        with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return properties

    def save_properties(self, case_identifier, properties):
        """
        Pickle the properties of a case into the output folder.

        :param case_identifier: name of the case ("<case_identifier>.pkl" is written)
        :param properties: object to pickle
        """
        with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f:
            pickle.dump(properties, f)