import os
from copy import deepcopy
from typing import Union, List

import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle

from nnunetv2.configuration import default_num_processes
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager

import SimpleITK as sitk


def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
                                                                plans_manager: PlansManager,
                                                                configuration_manager: ConfigurationManager,
                                                                label_manager: LabelManager,
                                                                properties_dict: dict,
                                                                return_probabilities: bool = False,
                                                                num_threads_torch: int = default_num_processes):
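    """
    Resample the predicted logits back to the original image geometry, apply the inference nonlinearity,
    convert the probabilities to a segmentation map and revert cropping and transposition.
    Returns the segmentation array (plus the probabilities if return_probabilities is True).
    """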
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)

    # resample to original shape
    current_spacing = configuration_manager.spacing if \
        len(configuration_manager.spacing) == \
        len(properties_dict['shape_after_cropping_and_before_resampling']) else \
        [properties_dict['spacing'][0], *configuration_manager.spacing]
    predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,
                                                                         properties_dict['shape_after_cropping_and_before_resampling'],
                                                                         current_spacing,
                                                                         properties_dict['spacing'])
    # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
    # apply_inference_nonlin will convert to torch
    predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
    del predicted_logits
    segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)

    # segmentation may be torch.Tensor but we continue with numpy
    if isinstance(segmentation, torch.Tensor):
        segmentation = segmentation.cpu().numpy()

    # put segmentation in bbox (revert cropping)
    segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
                                              dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
    slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
    segmentation_reverted_cropping[slicer] = segmentation
    del segmentation

    # revert transpose
    segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)

    if return_probabilities:
        # revert cropping
        predicted_probabilities = label_manager.revert_cropping_on_probabilities(predicted_probabilities,
                                                                                 properties_dict['bbox_used_for_cropping'],
                                                                                 properties_dict['shape_before_cropping'])
        predicted_probabilities = predicted_probabilities.cpu().numpy()
        # revert transpose
        predicted_probabilities = predicted_probabilities.transpose([0] + [i + 1 for i in
                                                                           plans_manager.transpose_backward])
        torch.set_num_threads(old_threads)
        return segmentation_reverted_cropping, predicted_probabilities
    else:
        torch.set_num_threads(old_threads)
        return segmentation_reverted_cropping


def convert_predicted_image_to_original_shape(predicted_image: Union[torch.Tensor, np.ndarray],  # arthur
                                              plans_manager: PlansManager,
                                              configuration_manager: ConfigurationManager,
                                              properties_dict: dict,
                                              num_threads_torch: int = default_num_processes):
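    """
    Custom addition (arthur): resamples a predicted image (e.g. a single-channel regression/reconstruction
    output rather than segmentation logits) back to the original geometry by resampling to the
    pre-resampling shape, reverting the crop and reverting the transpose.
    """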
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)

    # Resample to original shape
    # Assuming configuration_manager has a resampling function for images
    current_spacing = configuration_manager.spacing if \
        len(configuration_manager.spacing) == \
        len(properties_dict['shape_after_cropping_and_before_resampling']) else \
        [properties_dict['spacing'][0], *configuration_manager.spacing]
    predicted_resampled = configuration_manager.resampling_fn_data(predicted_image,
                                                                   properties_dict['shape_after_cropping_and_before_resampling'],
                                                                   current_spacing,
                                                                   properties_dict['spacing'])

    # Ensure output is in numpy format for further processing if not already
    # print(predicted_resampled.shape)
    if isinstance(predicted_resampled, torch.Tensor):
        predicted_resampled = predicted_resampled.cpu().numpy()
    predicted_resampled = predicted_resampled[0]  # arthur: hardcoded first channel
    # sitk.WriteImage(sitk.GetImageFromArray(predicted_resampled.astype(np.float32)), "test_tanh.nii.gz")
    print(predicted_resampled.shape, np.min(predicted_resampled), np.max(predicted_resampled), predicted_resampled.dtype)

    # Put the image back into its original bounding box (revert cropping)
    original_shape_image = np.zeros(properties_dict['shape_before_cropping'],
                                    dtype=predicted_resampled.dtype)
    slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
    original_shape_image[slicer] = predicted_resampled
    del predicted_resampled

    # Revert transpose
    original_shape_image = original_shape_image.transpose(plans_manager.transpose_backward)

    torch.set_num_threads(old_threads)
    return original_shape_image


def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,
                                  configuration_manager: ConfigurationManager,
                                  plans_manager: PlansManager,
                                  dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
                                  save_probabilities: bool = False):
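    """
    Restores a prediction to the original image geometry and writes it to disk using the dataset's
    image reader/writer. In this modified version the image path (convert_predicted_image_to_original_shape)
    is used instead of the original segmentation path.
    """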
    # if isinstance(predicted_array_or_file, str):
    #     tmp = deepcopy(predicted_array_or_file)
    #     if predicted_array_or_file.endswith('.npy'):
    #         predicted_array_or_file = np.load(predicted_array_or_file)
    #     elif predicted_array_or_file.endswith('.npz'):
    #         predicted_array_or_file = np.load(predicted_array_or_file)['softmax']
    #     os.remove(tmp)

    if isinstance(dataset_json_dict_or_file, str):
        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
    # ret = convert_predicted_logits_to_segmentation_with_correct_shape(  # arthur removed
    #     predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,
    #     return_probabilities=save_probabilities
    # )
    ret = convert_predicted_image_to_original_shape(
        predicted_array_or_file, plans_manager, configuration_manager, properties_dict
    )
    del predicted_array_or_file

    # save
    if save_probabilities:
        # note: convert_predicted_image_to_original_shape returns only the restored image, so this branch
        # (which expects a (segmentation, probabilities) tuple) only works with the original logits-based
        # call that is commented out above
        segmentation_final, probabilities_final = ret
        np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final)
        save_pickle(properties_dict, output_file_truncated + '.pkl')
        del probabilities_final, ret
    else:
        segmentation_final = ret
        del ret

    rw = plans_manager.image_reader_writer_class()
    rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
                 properties_dict)
    return segmentation_final, properties_dict

    # print(f"[EXP][after export_prediction_from_logits] seg shape={seg.shape}, "
    #       f"dtype={seg.dtype}, unique={np.unique(seg)}")


def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str,
                      plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict,
                      dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes) \
        -> None:
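    """
    Used for the cascade: resamples lowres logits to the given target shape, converts them to a segmentation
    and saves the result as a compressed .npz file (key 'seg').
    """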
    # # needed for cascade
    # if isinstance(predicted, str):
    #     assert isfile(predicted), "If isinstance(segmentation_softmax, str) then " \
    #                               "isfile(segmentation_softmax) must be True"
    #     del_file = deepcopy(predicted)
    #     predicted = np.load(predicted)
    #     os.remove(del_file)

    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)

    if isinstance(dataset_json_dict_or_file, str):
        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

    # resample to original shape
    current_spacing = configuration_manager.spacing if \
        len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \
        [properties_dict['spacing'][0], *configuration_manager.spacing]
    target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \
        len(properties_dict['shape_after_cropping_and_before_resampling']) else \
        [properties_dict['spacing'][0], *configuration_manager.spacing]
    predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted,
                                                                                target_shape,
                                                                                current_spacing,
                                                                                target_spacing)

    # create segmentation (argmax, regions, etc)
    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
    segmentation = label_manager.convert_logits_to_segmentation(predicted_array_or_file)

    # segmentation may be torch.Tensor but we continue with numpy
    if isinstance(segmentation, torch.Tensor):
        segmentation = segmentation.cpu().numpy()

    np.savez_compressed(output_file, seg=segmentation.astype(np.uint8))
    torch.set_num_threads(old_threads)
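

# Hedged usage sketch: one way export_prediction_from_logits might be called after inference. The paths,
# the dummy 'prediction' array and the 'props' dict are placeholders; in a real run these objects are
# produced by nnU-Net's preprocessing and prediction pipeline.
#
# if __name__ == '__main__':
#     plans_manager = PlansManager('/path/to/plans.json')                      # hypothetical path
#     configuration_manager = plans_manager.get_configuration('3d_fullres')    # assumed configuration name
#     dataset_json = load_json('/path/to/dataset.json')                        # hypothetical path
#     prediction = np.zeros((1, 64, 64, 64), dtype=np.float32)                 # stand-in for a network output
#     props = load_json('/path/to/case_properties.json')                       # stand-in for the properties dict
#     export_prediction_from_logits(prediction, props, configuration_manager, plans_manager,
#                                   dataset_json, '/path/to/output/case_0001', save_probabilities=False)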