# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import SimpleITK as sitk import numpy as np import shutil from batchgenerators.utilities.file_and_folder_operations import * from multiprocessing import Pool from collections import OrderedDict import copy def create_nonzero_mask(data): from scipy.ndimage import binary_fill_holes assert len(data.shape) == 4 or len(data.shape) == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)" nonzero_mask = np.zeros(data.shape[1:], dtype=bool) for c in range(data.shape[0]): this_mask = data[c] != 0 nonzero_mask = nonzero_mask | this_mask nonzero_mask = binary_fill_holes(nonzero_mask) return nonzero_mask def get_bbox_from_mask(mask, outside_value=0): mask_voxel_coords = np.where(mask != outside_value) minzidx = int(np.min(mask_voxel_coords[0])) maxzidx = int(np.max(mask_voxel_coords[0])) + 1 minxidx = int(np.min(mask_voxel_coords[1])) maxxidx = int(np.max(mask_voxel_coords[1])) + 1 minyidx = int(np.min(mask_voxel_coords[2])) maxyidx = int(np.max(mask_voxel_coords[2])) + 1 return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]] def crop_to_bbox(image, bbox): if len(image.shape) == 3: resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1])) return image[resizer] elif len(image.shape) == 2: resizer = (slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1])) return image[resizer] def get_case_identifier(case): case_identifier = case[0].split("/")[-1].split(".nii.gz")[0][:-5] return case_identifier def get_case_identifier_from_npz(case): case_identifier = case.split("/")[-1][:-4] return case_identifier def load_case_from_list_of_files(data_files, seg_file=None): assert isinstance(data_files, list) or isinstance(data_files, tuple), "case must be either a list or a tuple" properties = OrderedDict() data_itk = [sitk.ReadImage(f) for f in data_files] properties["original_size_of_raw_data"] = np.array(data_itk[0].GetSize())[[2, 1, 0]] properties["original_spacing"] = np.array(data_itk[0].GetSpacing())[[2, 1, 0]] properties["list_of_data_files"] = data_files properties["seg_file"] = seg_file properties["itk_origin"] = data_itk[0].GetOrigin() properties["itk_spacing"] = data_itk[0].GetSpacing() properties["itk_direction"] = data_itk[0].GetDirection() data_npy = np.vstack([sitk.GetArrayFromImage(d)[None] for d in data_itk]) if seg_file is not None: seg_itk = sitk.ReadImage(seg_file) seg_npy = sitk.GetArrayFromImage(seg_itk)[None].astype(np.float32) else: seg_npy = None return data_npy.astype(np.float32), seg_npy, properties def crop_to_nonzero(data, seg=None, nonzero_label=0): """ :param data: :param seg: :param nonzero_label: this will be written into the segmentation map :return: """ nonzero_mask = create_nonzero_mask(data) bbox = get_bbox_from_mask(nonzero_mask, 0) cropped_data = [] for c in range(data.shape[0]): cropped = crop_to_bbox(data[c], bbox) cropped_data.append(cropped[None]) data = np.vstack(cropped_data) if not isinstance(seg, type(None)): if seg.shape[1] == data.shape[1]: if seg is not None: cropped_seg = [] for c in range(seg.shape[0]): cropped = crop_to_bbox(seg[c], bbox) cropped_seg.append(cropped[None]) seg = np.vstack(cropped_seg) nonzero_mask = crop_to_bbox(nonzero_mask, bbox)[None] if seg is not None: seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label else: nonzero_mask = nonzero_mask.astype(int) nonzero_mask[nonzero_mask == 0] = nonzero_label nonzero_mask[nonzero_mask > 0] = 0 seg = nonzero_mask return data, seg, bbox elif seg.shape[1] > data.shape[1]: # not very clean but should work bbox_for_seg = copy.copy(bbox) bbox_for_seg[0] = [0, seg.shape[1]] nonzero_mask_seg = np.array([nonzero_mask[0] for i in range(seg.shape[1])]) if seg is not None: cropped_seg = [] for c in range(seg.shape[0]): cropped = crop_to_bbox(seg[c], bbox_for_seg) cropped_seg.append(cropped[None]) seg = np.vstack(cropped_seg) # dont understand what happens here\ # all zeros values are set to -1. But why ? nonzero_mask = crop_to_bbox(nonzero_mask_seg, bbox_for_seg)[None] if seg is not None: seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label else: nonzero_mask = nonzero_mask.astype(int) nonzero_mask[nonzero_mask == 0] = nonzero_label nonzero_mask[nonzero_mask > 0] = 0 seg = nonzero_mask return data, seg, bbox return data, seg, bbox def get_patient_identifiers_from_cropped_files(folder): return [i.split("/")[-1][:-4] for i in subfiles(folder, join=True, suffix=".npz")] class ImageCropper(object): def __init__(self, num_threads, output_folder=None): """ This one finds a mask of nonzero elements (must be nonzero in all modalities) and crops the image to that mask. In the case of BRaTS and ISLES data this results in a significant reduction in image size :param num_threads: :param output_folder: whete to store the cropped data :param list_of_files: """ self.output_folder = output_folder self.num_threads = num_threads if self.output_folder is not None: maybe_mkdir_p(self.output_folder) @staticmethod def crop(data, properties, seg=None): shape_before = data.shape data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=0) shape_after = data.shape print("before crop:", shape_before, "after crop:", shape_after, "spacing:", np.array(properties["original_spacing"]), "\n") properties["crop_bbox"] = bbox # TODO can only work with <50 segmentation classes in one label if not isinstance(seg, type(None)): classes = [np.unique(segx) for segx in seg[0]] for i,c in enumerate(classes): classes[i] = c if len(c)<50 else [0] properties["classes"] = classes seg[seg < -1] = 0 properties["size_after_cropping"] = data[0].shape return data, seg, properties @staticmethod def crop_from_list_of_files(data_files, seg_file=None): data, seg, properties = load_case_from_list_of_files(data_files, seg_file) return ImageCropper.crop(data, properties, seg) def load_crop_save(self, case, case_identifier, overwrite_existing=False): try: print(case_identifier) if overwrite_existing \ or (not os.path.isfile(os.path.join(self.output_folder, "%s.npz" % case_identifier)) or not os.path.isfile(os.path.join(self.output_folder, "%s.pkl" % case_identifier))): data, seg, properties = self.crop_from_list_of_files(case[:-1], case[-1]) all_data = np.vstack((data, seg.transpose((1, 0, 2, 3)))) np.savez_compressed(os.path.join(self.output_folder, "%s.npz" % case_identifier), data=all_data) with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f: pickle.dump(properties, f) except Exception as e: print("Exception in", case_identifier, ":") print(e) raise e def get_list_of_cropped_files(self): return subfiles(self.output_folder, join=True, suffix=".npz") def get_patient_identifiers_from_cropped_files(self): return [i.split("/")[-1][:-4] for i in self.get_list_of_cropped_files()] def run_cropping(self, list_of_files, overwrite_existing=False, output_folder=None): """ also copied ground truth nifti segmentation into the preprocessed folder so that we can use them for evaluation on the cluster :param list_of_files: list of list of files [[PATIENTID_TIMESTEP_0000.nii.gz], [PATIENTID_TIMESTEP_0000.nii.gz]] :param overwrite_existing: :param output_folder: :return: """ if output_folder is not None: self.output_folder = output_folder output_folder_gt = os.path.join(self.output_folder, "gt_segmentations") maybe_mkdir_p(output_folder_gt) for j, case in enumerate(list_of_files): if case[-1] is not None: shutil.copy(case[-1], output_folder_gt) list_of_args = [] for j, case in enumerate(list_of_files): case_identifier = get_case_identifier(case) # What the fuck happens here list_of_args.append((case, case_identifier, overwrite_existing)) """ self.load_crop_save(case, case_identifier) """ p = Pool(self.num_threads) p.starmap(self.load_crop_save, list_of_args) p.close() p.join() def load_properties(self, case_identifier): with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'rb') as f: properties = pickle.load(f) return properties def save_properties(self, case_identifier, properties): with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f: pickle.dump(properties, f)