import os
from contextlib import contextmanager
from copy import deepcopy
from typing import Union, List, Tuple

import numpy as np
import torch
import SimpleITK as sitk
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle

from nnunetv2.configuration import default_num_processes
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager


@contextmanager
def _torch_num_threads(num_threads: int):
    """Temporarily set torch's thread count, restoring the old value on exit.

    The previous value is restored in a ``finally`` so that an exception
    raised inside the ``with`` body does not leave the process running with
    the overridden thread count.
    """
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        yield
    finally:
        torch.set_num_threads(old_threads)


def _spacing_matching_data_dim(configuration_manager: ConfigurationManager, properties_dict: dict):
    """Return the configuration spacing with as many entries as the data has axes.

    If ``configuration_manager.spacing`` already has the same length as
    ``properties_dict['shape_after_cropping_and_before_resampling']`` it is
    returned unchanged. Otherwise the first entry of
    ``properties_dict['spacing']`` is prepended to it.
    """
    spacing = configuration_manager.spacing
    if len(spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']):
        return spacing
    return [properties_dict['spacing'][0], *spacing]


def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
                                                                plans_manager: PlansManager,
                                                                configuration_manager: ConfigurationManager,
                                                                label_manager: LabelManager,
                                                                properties_dict: dict,
                                                                return_probabilities: bool = False,
                                                                num_threads_torch: int = default_num_processes) \
        -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Turn predicted logits into a segmentation in the original image geometry.

    Steps: resample the logits to ``shape_after_cropping_and_before_resampling``,
    apply the label manager's inference nonlinearity, convert to a
    segmentation, paste it into an all-zero array of ``shape_before_cropping``
    at ``bbox_used_for_cropping``, and finally apply
    ``plans_manager.transpose_backward``.

    Args:
        predicted_logits: network output to convert.
        plans_manager: provides ``transpose_backward``.
        configuration_manager: provides ``spacing`` and
            ``resampling_fn_probabilities``.
        label_manager: provides the nonlinearity, the
            probabilities-to-segmentation conversion and the cropping revert
            for probabilities.
        properties_dict: must contain ``'shape_after_cropping_and_before_resampling'``,
            ``'spacing'``, ``'shape_before_cropping'`` and
            ``'bbox_used_for_cropping'``.
        return_probabilities: if True, also return the probabilities with
            cropping and transpose reverted.
        num_threads_torch: torch thread count used while this function runs.

    Returns:
        The segmentation array, or ``(segmentation, probabilities)`` when
        ``return_probabilities`` is True.
    """
    with _torch_num_threads(num_threads_torch):
        # resample to original shape
        current_spacing = _spacing_matching_data_dim(configuration_manager, properties_dict)
        predicted_logits = configuration_manager.resampling_fn_probabilities(
            predicted_logits,
            properties_dict['shape_after_cropping_and_before_resampling'],
            current_spacing,
            properties_dict['spacing'])

        # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
        # apply_inference_nonlin will convert to torch
        predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
        del predicted_logits
        segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)

        # segmentation may be torch.Tensor but we continue with numpy
        if isinstance(segmentation, torch.Tensor):
            segmentation = segmentation.cpu().numpy()

        # put segmentation in bbox (revert cropping); the dtype is widened once
        # there are too many foreground labels for uint8
        seg_dtype = np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16
        segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'], dtype=seg_dtype)
        slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
        segmentation_reverted_cropping[slicer] = segmentation
        del segmentation

        # revert transpose
        segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)

        if not return_probabilities:
            return segmentation_reverted_cropping

        # revert cropping
        predicted_probabilities = label_manager.revert_cropping_on_probabilities(
            predicted_probabilities,
            properties_dict['bbox_used_for_cropping'],
            properties_dict['shape_before_cropping'])
        predicted_probabilities = predicted_probabilities.cpu().numpy()
        # revert transpose; axis 0 stays in front, the remaining axes are permuted
        predicted_probabilities = predicted_probabilities.transpose(
            [0] + [i + 1 for i in plans_manager.transpose_backward])
        return segmentation_reverted_cropping, predicted_probabilities


def convert_predicted_image_to_original_shape(predicted_image: Union[torch.Tensor, np.ndarray],
                                              plans_manager: PlansManager,
                                              configuration_manager: ConfigurationManager,
                                              properties_dict: dict,
                                              num_threads_torch: int = default_num_processes) -> np.ndarray:
    """Bring a predicted image (not logits) back into the original image geometry.

    The prediction is resampled with ``configuration_manager.resampling_fn_data``,
    index 0 along its first axis is selected, the result is pasted into an
    all-zero array of ``shape_before_cropping`` at ``bbox_used_for_cropping``,
    and ``plans_manager.transpose_backward`` is applied.

    Args:
        predicted_image: network output to convert.
        plans_manager: provides ``transpose_backward``.
        configuration_manager: provides ``spacing`` and ``resampling_fn_data``.
        properties_dict: must contain ``'shape_after_cropping_and_before_resampling'``,
            ``'spacing'``, ``'shape_before_cropping'`` and
            ``'bbox_used_for_cropping'``.
        num_threads_torch: torch thread count used while this function runs.

    Returns:
        A single numpy array with the dtype of the resampled prediction.
    """
    with _torch_num_threads(num_threads_torch):
        # resample to original shape
        current_spacing = _spacing_matching_data_dim(configuration_manager, properties_dict)
        predicted_resampled = configuration_manager.resampling_fn_data(
            predicted_image,
            properties_dict['shape_after_cropping_and_before_resampling'],
            current_spacing,
            properties_dict['spacing'])

        # continue with numpy regardless of what the resampling function returned
        if isinstance(predicted_resampled, torch.Tensor):
            predicted_resampled = predicted_resampled.cpu().numpy()

        # only the first channel is exported (hardcoded)
        predicted_resampled = predicted_resampled[0]

        # NOTE(review): debug output kept from the original code; consider
        # moving it behind a logger or verbosity flag.
        print(predicted_resampled.shape, np.min(predicted_resampled), np.max(predicted_resampled),
              predicted_resampled.dtype)

        # put the image back into its original bounding box (revert cropping)
        original_shape_image = np.zeros(properties_dict['shape_before_cropping'], dtype=predicted_resampled.dtype)
        slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
        original_shape_image[slicer] = predicted_resampled
        del predicted_resampled

        # revert transpose
        return original_shape_image.transpose(plans_manager.transpose_backward)


def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor],
                                  properties_dict: dict,
                                  configuration_manager: ConfigurationManager,
                                  plans_manager: PlansManager,
                                  dataset_json_dict_or_file: Union[dict, str],
                                  output_file_truncated: str,
                                  save_probabilities: bool = False) -> Tuple[np.ndarray, dict]:
    """Convert a prediction to the original geometry and write it to disk.

    The prediction is converted with ``convert_predicted_image_to_original_shape``
    and written through ``plans_manager.image_reader_writer_class().write_seg``
    to ``output_file_truncated + dataset_json['file_ending']``.

    Args:
        predicted_array_or_file: the prediction to export.
        properties_dict: image properties, forwarded to the conversion and to
            the writer.
        configuration_manager: forwarded to the conversion.
        plans_manager: forwarded to the conversion; also provides the label
            manager and the reader/writer class.
        dataset_json_dict_or_file: dataset json as a dict, or a path to it.
        output_file_truncated: output path without file ending.
        save_probabilities: if True, additionally write
            ``output_file_truncated + '.npz'`` (key ``probabilities``) and
            ``output_file_truncated + '.pkl'`` (the properties dict).

    Returns:
        ``(exported_array, properties_dict)``.
    """
    if isinstance(dataset_json_dict_or_file, str):
        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

    # NOTE(review): label_manager is not used by the image export path below;
    # the call is kept so this function's behavior is otherwise unchanged.
    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)

    # The image conversion replaces the logits-to-segmentation conversion here.
    # It returns ONE array, never a (segmentation, probabilities) pair.
    segmentation_final = convert_predicted_image_to_original_shape(
        predicted_array_or_file, plans_manager, configuration_manager, properties_dict
    )
    del predicted_array_or_file

    if save_probabilities:
        # The previous code tuple-unpacked the single returned array, which
        # either raised ValueError or silently split it along axis 0. There
        # are no separate probabilities on this path, so the full-resolution
        # prediction itself is stored under the 'probabilities' key.
        # NOTE(review): confirm downstream readers of the .npz accept this
        # array layout.
        np.savez_compressed(output_file_truncated + '.npz', probabilities=segmentation_final)
        save_pickle(properties_dict, output_file_truncated + '.pkl')

    rw = plans_manager.image_reader_writer_class()
    rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
                 properties_dict)
    return segmentation_final, properties_dict


def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str,
                      plans_manager: PlansManager, configuration_manager: ConfigurationManager,
                      properties_dict: dict, dataset_json_dict_or_file: Union[dict, str],
                      num_threads_torch: int = default_num_processes) -> None:
    """Resample a prediction to ``target_shape``, convert it to a segmentation and save it.

    The result is written with ``np.savez_compressed(output_file, seg=...)``
    after being cast to ``uint8``.

    Args:
        predicted: the prediction (logits) to resample.
        target_shape: shape to resample to.
        output_file: destination passed to ``np.savez_compressed``.
        plans_manager: provides the label manager.
        configuration_manager: provides ``spacing`` and
            ``resampling_fn_probabilities``.
        properties_dict: must contain ``'shape_after_cropping_and_before_resampling'``
            and ``'spacing'``.
        dataset_json_dict_or_file: dataset json as a dict, or a path to it.
        num_threads_torch: torch thread count used while this function runs.
    """
    with _torch_num_threads(num_threads_torch):
        if isinstance(dataset_json_dict_or_file, str):
            dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

        # The original code derived current and target spacing from the same
        # expression; both are still computed independently here.
        current_spacing = _spacing_matching_data_dim(configuration_manager, properties_dict)
        target_spacing = _spacing_matching_data_dim(configuration_manager, properties_dict)
        resampled = configuration_manager.resampling_fn_probabilities(predicted, target_shape, current_spacing,
                                                                      target_spacing)

        # create segmentation (argmax, regions, etc)
        label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
        segmentation = label_manager.convert_logits_to_segmentation(resampled)

        # segmentation may be torch.Tensor but we continue with numpy
        if isinstance(segmentation, torch.Tensor):
            segmentation = segmentation.cpu().numpy()
        np.savez_compressed(output_file, seg=segmentation.astype(np.uint8))